mirror of https://github.com/llvm/torch-mlir
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243)
Like #3130, gradually replace the deprecated code https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecatedpull/3244/head
parent
466618e45e
commit
6679728c56
|
@ -23,7 +23,7 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
|
|||
int64_t dimA, int64_t dimB,
|
||||
Value &transposed) {
|
||||
Type transposedType;
|
||||
if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
|
||||
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
|
||||
dimA, dimB, transposedType)))
|
||||
return failure();
|
||||
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
@ -554,7 +554,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
// conversions which are not supported in Torch-MLIR right now.
|
||||
|
||||
Torch::ValueTensorType targetTy =
|
||||
target.getType().cast<Torch::ValueTensorType>();
|
||||
cast<Torch::ValueTensorType>(target.getType());
|
||||
if (!targetTy.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"target tensor must have a dtype");
|
||||
|
@ -753,9 +753,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
Type listElemType =
|
||||
tensors[0]
|
||||
.getType()
|
||||
.cast<Torch::BaseTensorType>()
|
||||
cast<Torch::BaseTensorType>(tensors[0].getType())
|
||||
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
||||
/*optionalDtype=*/nullptr);
|
||||
Type listType = Torch::ListType::get(listElemType);
|
||||
|
@ -869,7 +867,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
|
||||
auto weightTensorType = cast<Torch::ValueTensorType>(weight.getType());
|
||||
if (!weightTensorType || !weightTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected weight type having sizes");
|
||||
|
@ -1188,7 +1186,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
|
||||
auto weightTensorType = cast<Torch::ValueTensorType>(weight.getType());
|
||||
if (!weightTensorType || !weightTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected weight type having sizes");
|
||||
|
@ -1427,7 +1425,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.customOpNameStringAttr(mode, "mode", "DCR") ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
auto inputTy = input.getType().dyn_cast<Torch::BaseTensorType>();
|
||||
auto inputTy = dyn_cast<Torch::BaseTensorType>(input.getType());
|
||||
if (!inputTy || !inputTy.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected input type having sizes");
|
||||
|
@ -1536,9 +1534,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
Value scale = operands[1];
|
||||
Value zeropoint = operands[2];
|
||||
|
||||
auto operandTy = operand.getType().cast<Torch::ValueTensorType>();
|
||||
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
|
||||
|
||||
auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
|
||||
if (!scaleTy || !scaleTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
|
||||
if (!resultType.hasDtype())
|
||||
|
@ -1611,7 +1609,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
|
||||
Value trainVal = operands[2];
|
||||
auto trainTensorType =
|
||||
trainVal.getType().dyn_cast<Torch::BaseTensorType>();
|
||||
dyn_cast<Torch::BaseTensorType>(trainVal.getType());
|
||||
if (!trainTensorType)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"train tensor must have a type");
|
||||
|
@ -1629,8 +1627,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
|
||||
if (auto valueTensorLiteralOp =
|
||||
trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseElementsAttr>()
|
||||
auto val = cast<DenseElementsAttr>(valueTensorLiteralOp.getValue())
|
||||
.getSplatValue<bool>();
|
||||
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
|
||||
} else {
|
||||
|
@ -2072,7 +2069,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
|
||||
SmallVector<Value> dimList;
|
||||
Torch::BaseTensorType shapeType =
|
||||
shape.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(shape.getType());
|
||||
Type selectResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({}), shapeType.getOptionalDtype());
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
|
|
@ -104,10 +104,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "operand grid_sampler bind failure");
|
||||
|
||||
auto inputTensorType = input.getType().cast<Torch::ValueTensorType>();
|
||||
auto inputTensorType = cast<Torch::ValueTensorType>(input.getType());
|
||||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||
uint32_t inputRank = inputShape.size();
|
||||
auto gridTensorType = grid.getType().cast<Torch::ValueTensorType>();
|
||||
auto gridTensorType = cast<Torch::ValueTensorType>(grid.getType());
|
||||
ArrayRef<int64_t> gridShape = gridTensorType.getSizes();
|
||||
uint32_t gridRank = gridShape.size();
|
||||
|
||||
|
@ -233,7 +233,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
axis = rank + axis;
|
||||
}
|
||||
// need input type and sizes to flatten/unflatten later.
|
||||
auto inputTy = input.getType().cast<Torch::ValueTensorType>();
|
||||
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
|
||||
if (!inputTy || !inputTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "failed to get input type or sizes");
|
||||
|
@ -1065,7 +1065,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
|
||||
|
||||
auto transpose = [&](Value m) -> Value {
|
||||
auto tty = m.getType().cast<Torch::ValueTensorType>();
|
||||
auto tty = cast<Torch::ValueTensorType>(m.getType());
|
||||
auto shape = tty.getOptionalSizes();
|
||||
if (shape.has_value()) {
|
||||
llvm::SmallVector<int64_t> newShape(shape.value());
|
||||
|
@ -1134,7 +1134,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
|
||||
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
|
||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected input type having sizes");
|
||||
|
@ -1228,7 +1228,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
rank = *maybeRank;
|
||||
SmallVector<Value> normalized;
|
||||
axis = Torch::toPositiveDim(axis, rank);
|
||||
auto xType = x.getType().cast<Torch::ValueTensorType>();
|
||||
auto xType = cast<Torch::ValueTensorType>(x.getType());
|
||||
if (!xType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected input (X) to have sizes");
|
||||
|
@ -1307,7 +1307,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
|
||||
// Get pads shape and rank. The pads tensor is expected to be 1-D
|
||||
// tensor.
|
||||
auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
|
||||
auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
|
||||
if (!padsTensorType || !padsTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Expect non empty pad tensor");
|
||||
|
@ -1323,7 +1323,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
// As per onnx.Pad documentation, padSize = 2*num_data_axes
|
||||
// (if axes param not passed). Need to be updated when adding
|
||||
// support for `axes` param.
|
||||
auto dataOpTy = data.getType().cast<Torch::ValueTensorType>();
|
||||
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
|
||||
TensorType dataTensor = dataOpTy.toBuiltinTensor();
|
||||
if (!dataTensor || !dataTensor.hasRank())
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1350,7 +1350,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
}
|
||||
|
||||
if (!constantValue) {
|
||||
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
|
||||
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
|
||||
if (dataTensorType.getDtype().isa<IntegerType>())
|
||||
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
|
|
|
@ -54,7 +54,7 @@ LogicalResult reducedSumImpl(OpBinder binder,
|
|||
SmallVector<Value> axesList;
|
||||
Value axesVal;
|
||||
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: expected input and result to have shapes");
|
||||
|
@ -97,7 +97,7 @@ LogicalResult reducedSumImpl(OpBinder binder,
|
|||
}
|
||||
if (axesList.empty()) {
|
||||
Torch::BaseTensorType axesType =
|
||||
axesVal.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(axesVal.getType());
|
||||
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
||||
auto axesShape = axesTy.getSizes();
|
||||
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
||||
|
@ -177,7 +177,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value scale = operands[1];
|
||||
Value zeropoint = operands[2];
|
||||
|
||||
auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
|
||||
if (!scaleTy || !scaleTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
|
||||
if (!resultType.hasDtype())
|
||||
|
@ -241,7 +241,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value c = operands.size() == 9 ? operands[8] : nullptr;
|
||||
|
||||
auto check = [](Value v) {
|
||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
||||
auto vTy = cast<Torch::ValueTensorType>(v.getType());
|
||||
return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
|
||||
};
|
||||
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
|
||||
|
@ -250,7 +250,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.op, "not supported for non per-tensor quantization");
|
||||
|
||||
auto extract = [&rewriter, &binder](Value v) {
|
||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
||||
auto vTy = cast<Torch::ValueTensorType>(v.getType());
|
||||
Type extractTy = rewriter.getType<Torch::FloatType>();
|
||||
if (isa<IntegerType>(vTy.getDtype()))
|
||||
extractTy = rewriter.getType<Torch::IntType>();
|
||||
|
@ -268,7 +268,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
auto make = [&rewriter, &binder](Value v, Value scale,
|
||||
Value zp) -> Value {
|
||||
auto ty = v.getType().cast<Torch::ValueTensorType>();
|
||||
auto ty = cast<Torch::ValueTensorType>(v.getType());
|
||||
auto newTy = getQTorchTypeFromTorchIntType(ty);
|
||||
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
|
||||
binder.getLoc(), newTy, v, scale, zp);
|
||||
|
@ -351,7 +351,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value cZp = operands[7];
|
||||
|
||||
auto check = [](Value v) {
|
||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
||||
auto vTy = cast<Torch::ValueTensorType>(v.getType());
|
||||
for (auto dim : vTy.getSizes())
|
||||
if (dim != 1)
|
||||
return false;
|
||||
|
@ -368,7 +368,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
rewriter.getType<Torch::IntType>()),
|
||||
ValueRange{});
|
||||
auto extract = [&rewriter, &binder, &emptyList](Value v) {
|
||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
||||
auto vTy = cast<Torch::ValueTensorType>(v.getType());
|
||||
if (!vTy.getSizes().empty()) {
|
||||
vTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({}), vTy.getOptionalDtype());
|
||||
|
@ -393,7 +393,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
auto make = [&rewriter, &binder](Value v, Value scale,
|
||||
Value zp) -> Value {
|
||||
auto ty = v.getType().cast<Torch::ValueTensorType>();
|
||||
auto ty = cast<Torch::ValueTensorType>(v.getType());
|
||||
auto newTy = getQTorchTypeFromTorchIntType(ty);
|
||||
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
|
||||
binder.getLoc(), newTy, v, scale, zp);
|
||||
|
@ -667,7 +667,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
return failure();
|
||||
|
||||
Value data = inputOperands[0];
|
||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||
if (!inputType.hasSizes() || !resultType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
|
@ -718,7 +718,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
if (dimList.empty()) {
|
||||
Value axes = inputOperands[1];
|
||||
Torch::BaseTensorType axesType =
|
||||
axes.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(axes.getType());
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
|
@ -760,7 +760,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
if (binder.tensorOperands(data, axes) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||
if (!inputType.hasSizes() || !resultType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
|
@ -925,8 +925,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
// Perform an AtenToDtype op on the squared sum of the operand, stored
|
||||
// now in operand itself.
|
||||
auto size = operand.getType()
|
||||
.dyn_cast<Torch::ValueTensorType>()
|
||||
auto size = dyn_cast<Torch::ValueTensorType>(operand.getType())
|
||||
.getOptionalSizes();
|
||||
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||
size, rewriter.getF32Type());
|
||||
|
@ -1005,7 +1004,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
Value axesVal;
|
||||
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
|
@ -1053,7 +1052,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
if (axesList.empty()) {
|
||||
Torch::BaseTensorType axesType =
|
||||
axesVal.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(axesVal.getType());
|
||||
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
||||
auto axesShape = axesTy.getSizes();
|
||||
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
||||
|
@ -1191,7 +1190,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
// Extract the axes values from the axes operand:
|
||||
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
||||
Torch::BaseTensorType axesType =
|
||||
axes.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(axes.getType());
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
|
@ -1344,7 +1343,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
// Extract the axes values from the axes operand:
|
||||
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
||||
Torch::BaseTensorType axesType =
|
||||
axes.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(axes.getType());
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
|
@ -1467,12 +1466,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
auto loc = binder.getLoc();
|
||||
auto result0Ty =
|
||||
binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
|
||||
auto resultNTy = binder.op->getResults()
|
||||
.back()
|
||||
.getType()
|
||||
.cast<Torch::ValueTensorType>();
|
||||
auto selfTy = self.getType().cast<Torch::ValueTensorType>();
|
||||
cast<Torch::ValueTensorType>(binder.op->getResult(0).getType());
|
||||
auto resultNTy = cast<Torch::ValueTensorType>(
|
||||
binder.op->getResults().back().getType());
|
||||
auto selfTy = cast<Torch::ValueTensorType>(self.getType());
|
||||
|
||||
int64_t dim = axis;
|
||||
if (dim < 0)
|
||||
|
@ -1555,7 +1552,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.op, "Failed to get num_outputs attribute");
|
||||
|
||||
auto result0Ty =
|
||||
binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
|
||||
cast<Torch::ValueTensorType>(binder.op->getResult(0).getType());
|
||||
auto selfTy =
|
||||
cast<Torch::ValueTensorType>(binder.op->getOperand(0).getType());
|
||||
|
||||
|
@ -1617,7 +1614,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
if (binder.tensorOperand(operand) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
auto operandType = operand.getType().cast<Torch::ValueTensorType>();
|
||||
auto operandType = cast<Torch::ValueTensorType>(operand.getType());
|
||||
TensorType tensorType = operandType.toBuiltinTensor();
|
||||
if (!tensorType || !tensorType.hasRank())
|
||||
return failure();
|
||||
|
@ -1705,26 +1702,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
}
|
||||
|
||||
auto context = rewriter.getContext();
|
||||
auto operandTorchTy = operand.getType().cast<Torch::ValueTensorType>();
|
||||
auto operandTorchTy = cast<Torch::ValueTensorType>(operand.getType());
|
||||
auto operandTy =
|
||||
operandTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(operandTorchTy.toBuiltinTensor());
|
||||
|
||||
if (!operandTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"Expected tensor operator argument to be a ranked tensor type");
|
||||
|
||||
auto startsTorchTy = starts.getType().cast<Torch::ValueTensorType>();
|
||||
auto startsTorchTy = cast<Torch::ValueTensorType>(starts.getType());
|
||||
auto startsTy =
|
||||
startsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(startsTorchTy.toBuiltinTensor());
|
||||
int startSize = startsTy.getDimSize(0);
|
||||
|
||||
auto endsTorchTy = ends.getType().cast<Torch::ValueTensorType>();
|
||||
auto endsTy =
|
||||
endsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
||||
auto endsTorchTy = cast<Torch::ValueTensorType>(ends.getType());
|
||||
auto endsTy = dyn_cast<RankedTensorType>(endsTorchTy.toBuiltinTensor());
|
||||
int endSize = endsTy.getDimSize(0);
|
||||
auto resultTy =
|
||||
resultTorchType.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(resultTorchType.toBuiltinTensor());
|
||||
if (!resultTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected result type to be a ranked tensor type");
|
||||
|
@ -1768,9 +1764,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
"and their dimensions to match");
|
||||
|
||||
if (axes) {
|
||||
auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
|
||||
auto axesTorchTy = cast<Torch::ValueTensorType>(axes.getType());
|
||||
auto axesTy =
|
||||
axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(axesTorchTy.toBuiltinTensor());
|
||||
int64_t numAxes = axesTy.getDimSize(0);
|
||||
|
||||
if (!(axesTy && numAxes == endSize))
|
||||
|
@ -1792,7 +1788,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
|
||||
auto select = [&](Value v, Value k) -> Value {
|
||||
auto ty = v.getType().cast<Torch::ValueTensorType>();
|
||||
auto ty = cast<Torch::ValueTensorType>(v.getType());
|
||||
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
|
||||
loc,
|
||||
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
|
||||
|
@ -1872,7 +1868,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
}
|
||||
|
||||
Torch::BaseTensorType shapeType =
|
||||
shape.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(shape.getType());
|
||||
SmallVector<Value> dimList;
|
||||
SmallVector<int64_t> selectSizes;
|
||||
selectSizes.push_back(1);
|
||||
|
@ -2007,7 +2003,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
// instead of using the dynamic axes at operand[1].
|
||||
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
||||
Torch::BaseTensorType axesType =
|
||||
axes.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(axes.getType());
|
||||
auto sizes = axesType.getSizes();
|
||||
for (int i = 0; i < sizes[0]; i++) {
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
@ -2136,7 +2132,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
// int32, int64 Assuming start, limit and delta to be same type (could
|
||||
// they be different?)
|
||||
Torch::BaseTensorType startTensorType =
|
||||
start.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(start.getType());
|
||||
bool isFloatDType = startTensorType.getDtype().isF64() ||
|
||||
startTensorType.getDtype().isF32();
|
||||
bool isIntDType = startTensorType.getDtype().isInteger(16) ||
|
||||
|
@ -2222,7 +2218,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
SmallVector<int64_t> selectSizes;
|
||||
selectSizes.push_back(1);
|
||||
Torch::BaseTensorType shapeType =
|
||||
repeatDims.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(repeatDims.getType());
|
||||
Type selectResultType = shapeType.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
|
|
@ -95,7 +95,7 @@ public:
|
|||
Value input = adaptor.getA();
|
||||
Type resultType =
|
||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||
if (!input.getType().isa<mlir::FloatType>())
|
||||
if (!isa<mlir::FloatType>(input.getType()))
|
||||
input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type());
|
||||
Value result = rewriter.create<UnaryOp>(loc, input);
|
||||
rewriter.replaceOp(op,
|
||||
|
@ -172,8 +172,8 @@ public:
|
|||
matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MLIRContext *context = op->getContext();
|
||||
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
|
||||
if (auto type = dyn_cast<RankedTensorType>(elements.getType())) {
|
||||
Type elemTy = op.getValueAttr().getElementType();
|
||||
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
|
||||
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
|
||||
|
@ -187,9 +187,9 @@ public:
|
|||
}
|
||||
}
|
||||
if (auto elements =
|
||||
op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
|
||||
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
|
||||
dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
|
||||
if (auto type = dyn_cast<RankedTensorType>(elements.getType())) {
|
||||
if (auto intType = dyn_cast<IntegerType>(type.getElementType())) {
|
||||
Type builtinTensorElemTy =
|
||||
IntegerType::get(context, intType.getIntOrFloatBitWidth());
|
||||
auto shapedType =
|
||||
|
|
|
@ -49,8 +49,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|||
SmallVector<Value> &strides) {
|
||||
Location loc = op.getLoc();
|
||||
auto input = adaptor.getSelf();
|
||||
RankedTensorType inputType =
|
||||
input.getType().template cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
|
||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
|
@ -73,8 +72,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|||
Value builtinTypeStart = adaptor.getStart();
|
||||
Value builtinTypeEnd = adaptor.getEnd();
|
||||
|
||||
if (torchTypeStart.getType().isa<OptionalType>() ||
|
||||
torchTypeEnd.getType().isa<OptionalType>())
|
||||
if (isa<OptionalType>(torchTypeStart.getType()) ||
|
||||
isa<OptionalType>(torchTypeEnd.getType()))
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||
|
||||
Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
|
||||
|
@ -84,7 +83,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|||
// We cannot use to positive valid dim as for negative strides we need to
|
||||
// clamp to `-1` so that the full tensor bounds are available:
|
||||
Value end = builtinTypeEnd;
|
||||
if (torchTypeEnd.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(torchTypeEnd.getType())) {
|
||||
end = dimSize;
|
||||
} else {
|
||||
end = castIntToIndex(rewriter, loc, end);
|
||||
|
@ -594,7 +593,7 @@ public:
|
|||
int64_t endDim;
|
||||
if (!matchPattern(op.getEndDim(), m_TorchConstantInt(&endDim)))
|
||||
return rewriter.notifyMatchFailure(op, "end_dim must be constant");
|
||||
auto type = adaptor.getSelf().getType().cast<RankedTensorType>();
|
||||
auto type = cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
auto inputRank = type.getRank();
|
||||
if (inputRank == 1) {
|
||||
// If input rank is equal to 1, then there's no scope for flattening the
|
||||
|
@ -604,7 +603,7 @@ public:
|
|||
}
|
||||
|
||||
auto resultType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
if (startDim < 0)
|
||||
startDim += inputRank;
|
||||
if (endDim < 0)
|
||||
|
@ -652,7 +651,7 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType outputTensorType = cast<BaseTensorType>(op.getType());
|
||||
if (!outputTensorType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: output must have known sizes");
|
||||
|
@ -660,7 +659,7 @@ public:
|
|||
std::optional<unsigned> maybeRank = getTensorRank(self);
|
||||
if (!maybeRank)
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
||||
auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
|
||||
auto inputTensorType = cast<Torch::ValueTensorType>(self.getType());
|
||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Expected input type having sizes");
|
||||
|
@ -901,7 +900,7 @@ public:
|
|||
getInputAndOutputShape(Value inputTorchTensor,
|
||||
SmallVector<Value> outputSizeTorchInt) {
|
||||
SmallVector<int64_t> inputShape(
|
||||
inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
|
||||
cast<BaseTensorType>(inputTorchTensor.getType()).getSizes());
|
||||
SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
|
||||
for (auto [outputDim, outputDimSize] :
|
||||
llvm::enumerate(outputSizeTorchInt)) {
|
||||
|
@ -945,11 +944,11 @@ public:
|
|||
return failure();
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
int64_t inputRank = inputType.getRank();
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
int64_t resultRank = resultType.getRank();
|
||||
if (resultRank == 0) {
|
||||
rewriter
|
||||
|
@ -1349,7 +1348,7 @@ public:
|
|||
auto outputDims = b.create<tensor::FromElementsOp>(ty, sizes);
|
||||
|
||||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, resultType, self,
|
||||
outputDims);
|
||||
return success();
|
||||
|
@ -1367,13 +1366,13 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputShape = inputType.getShape();
|
||||
int64_t inputRank = inputType.getRank();
|
||||
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
auto resultShape = resultType.getShape();
|
||||
int64_t resultRank = resultType.getRank();
|
||||
|
||||
|
@ -1437,7 +1436,7 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
int64_t inputRank = inputType.getRank();
|
||||
|
||||
if (inputRank == 0) {
|
||||
|
@ -1460,7 +1459,7 @@ public:
|
|||
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
int64_t resultRank = resultType.getRank();
|
||||
|
||||
// If the dim(th) dimension of operand tensor type is not statically unit,
|
||||
|
@ -1510,7 +1509,7 @@ public:
|
|||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
||||
auto inputRank =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
dim = toPositiveDim(dim, inputRank + 1);
|
||||
if (!isValidDim(dim, inputRank + 1))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
@ -1535,9 +1534,8 @@ public:
|
|||
}
|
||||
}
|
||||
}
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
||||
op, resultType, adaptor.getSelf(), reassociationMap);
|
||||
return success();
|
||||
|
@ -1564,11 +1562,10 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
||||
|
||||
auto inVector = adaptor.getSelf();
|
||||
auto inType = inVector.getType().cast<RankedTensorType>();
|
||||
auto inType = cast<RankedTensorType>(inVector.getType());
|
||||
auto inputRank = inType.getRank();
|
||||
auto outType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto outType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
auto elementType = inType.getElementType();
|
||||
|
||||
dim0 = toPositiveDim(dim0, inputRank);
|
||||
|
@ -1634,11 +1631,10 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "all dimensions must be constant");
|
||||
|
||||
Value inVector = adaptor.getSelf();
|
||||
auto inType = inVector.getType().cast<RankedTensorType>();
|
||||
auto inType = cast<RankedTensorType>(inVector.getType());
|
||||
int64_t inputRank = inType.getRank();
|
||||
auto outType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto outType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type elementType = inType.getElementType();
|
||||
|
||||
// Check if the dimensions are a valid constants.
|
||||
|
@ -1747,7 +1743,7 @@ public:
|
|||
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
|
||||
|
||||
RankedTensorType newResultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
int rank = newResultType.getRank();
|
||||
Value dimValue = op.getDim();
|
||||
int64_t dim;
|
||||
|
@ -1802,7 +1798,7 @@ public:
|
|||
// which in this case is `inShapeConverted` because this shape will yield
|
||||
// us the dimension size of the output.
|
||||
SmallVector<bool> useBroadcastToShape;
|
||||
int64_t inputRank = self.getType().cast<RankedTensorType>().getRank();
|
||||
int64_t inputRank = cast<RankedTensorType>(self.getType()).getRank();
|
||||
for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e;
|
||||
++i) {
|
||||
int64_t dim;
|
||||
|
@ -1821,7 +1817,7 @@ public:
|
|||
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
|
||||
rewriter, op.getLoc(), getTypeConverter(), inShape);
|
||||
auto newResultType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
Value result;
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||
op, rewriter, self, inShapeConverted, newResultType, result,
|
||||
|
@ -1869,7 +1865,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value self = adaptor.getSelf();
|
||||
Value src = adaptor.getSrc();
|
||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
||||
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||
|
||||
// The non_blocking should be a constant `False`.
|
||||
bool nonBlocking;
|
||||
|
@ -1954,7 +1950,7 @@ public:
|
|||
}
|
||||
|
||||
Value src = adaptor.getSrc();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
auto srcType = cast<RankedTensorType>(src.getType());
|
||||
int64_t srcRank = srcType.getRank();
|
||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||
// TODO: audit possibility of sparsity on these tensor
|
||||
|
@ -1992,7 +1988,7 @@ public:
|
|||
auto input = adaptor.getSelf();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
auto elementType = resultType.getElementType();
|
||||
SmallVector<Value> resultShape;
|
||||
|
@ -2070,9 +2066,9 @@ public:
|
|||
auto input = adaptor.getSelf();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputElementType = getElementTypeOrSelf(input.getType());
|
||||
if (!isa<ComplexType>(inputElementType)) {
|
||||
return op.emitError("only ComplexType is allowed as input type");
|
||||
|
@ -2157,7 +2153,7 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "dim2 must be constant");
|
||||
|
||||
Value inputMatrix = adaptor.getSelf();
|
||||
RankedTensorType inputType = inputMatrix.getType().cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(inputMatrix.getType());
|
||||
int64_t inputRank = inputType.getRank();
|
||||
|
||||
if (inputRank < 2)
|
||||
|
@ -2277,7 +2273,7 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern<AtenDiagEmbedOp> {
|
|||
static SmallVector<Value>
|
||||
getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor,
|
||||
int64_t offset, int64_t dim1, int64_t dim2) {
|
||||
auto inputType = tensor.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(tensor.getType());
|
||||
auto inputRank = inputType.getRank();
|
||||
|
||||
// output tensor always has 1 extra dimension
|
||||
|
@ -2314,7 +2310,7 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputRank = inputType.getRank();
|
||||
auto resultRank = inputRank + 1;
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ public:
|
|||
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
|
||||
return op.emitError("unimplemented: dim is not constant");
|
||||
int64_t inputRank =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
@ -88,7 +88,7 @@ public:
|
|||
Value indices = adaptor.getIndex();
|
||||
Value self = adaptor.getSelf();
|
||||
RankedTensorType newResultTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
int64_t rank = newResultTy.getRank();
|
||||
|
||||
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
|
||||
|
@ -128,9 +128,9 @@ public:
|
|||
Value weight = adaptor.getWeight();
|
||||
Value indices = adaptor.getIndices();
|
||||
RankedTensorType newResultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
if (weightTy.getRank() != 2)
|
||||
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
|
||||
Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
|
||||
|
@ -140,7 +140,7 @@ public:
|
|||
sizes.push_back(embeddingDim);
|
||||
int64_t resultRank = sizes.size();
|
||||
|
||||
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
||||
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
||||
int64_t indicesRank = indicesTy.getRank();
|
||||
SmallVector<AffineExpr> indicesExprs;
|
||||
for (int i = 0; i < indicesRank; i++)
|
||||
|
@ -274,15 +274,15 @@ public:
|
|||
"include_last_offset is expected to be a constant boolean value.");
|
||||
}
|
||||
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
if (weightTy.getRank() != 2)
|
||||
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
|
||||
|
||||
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
||||
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
||||
if (indicesTy.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "indices must be a vector");
|
||||
|
||||
auto offsetsTy = offsets.getType().cast<RankedTensorType>();
|
||||
auto offsetsTy = cast<RankedTensorType>(offsets.getType());
|
||||
if (offsetsTy.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "offsets much be a vector");
|
||||
|
||||
|
@ -471,10 +471,9 @@ public:
|
|||
Value input = adaptor.getSelf();
|
||||
Value indices = adaptor.getIndex();
|
||||
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type elementType = resultType.getElementType();
|
||||
unsigned inputRank = inputType.getRank();
|
||||
|
||||
|
@ -604,10 +603,9 @@ public:
|
|||
op, "aten.index.Tensor: index tensor must not be None");
|
||||
}
|
||||
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type elementType = resultType.getElementType();
|
||||
int inputRank = inputType.getRank();
|
||||
int resultRank = resultType.getRank();
|
||||
|
@ -625,7 +623,7 @@ public:
|
|||
int maxRank = -1;
|
||||
for (auto indexTensor : indexTensors) {
|
||||
RankedTensorType indexTensorType =
|
||||
indexTensor.getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(indexTensor.getType());
|
||||
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
|
||||
}
|
||||
|
||||
|
@ -639,7 +637,7 @@ public:
|
|||
int64_t staticDimSize = -1;
|
||||
for (auto indexTensor : indexTensors) {
|
||||
RankedTensorType indexTensorType =
|
||||
indexTensor.getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(indexTensor.getType());
|
||||
int64_t indexTensorRank = indexTensorType.getRank();
|
||||
if ((maxRank - indexTensorRank) > (i - startIndex))
|
||||
continue;
|
||||
|
@ -714,7 +712,7 @@ public:
|
|||
|
||||
for (auto indexTensor : indexTensors) {
|
||||
RankedTensorType indexTensorType =
|
||||
indexTensor.getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(indexTensor.getType());
|
||||
auto indexTensorShape =
|
||||
makeShapeTorchCompatible(indexTensorType.getShape());
|
||||
int rank = indexTensorShape.size();
|
||||
|
@ -828,7 +826,7 @@ public:
|
|||
Value input = adaptor.getSelf();
|
||||
|
||||
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputRank = inputType.getRank();
|
||||
Type elementType = inputType.getElementType();
|
||||
|
||||
|
@ -989,7 +987,7 @@ public:
|
|||
Value gradOutput = adaptor.getGradOutput();
|
||||
|
||||
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
|
||||
auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
|
||||
auto gradOutputType = cast<RankedTensorType>(gradOutput.getType());
|
||||
auto gradOutputRank = gradOutputType.getRank();
|
||||
Type elementType = gradOutputType.getElementType();
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
|
|||
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
||||
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, ValueRange{arg},
|
||||
arg.getType().cast<TensorType>().getElementType(),
|
||||
cast<TensorType>(arg.getType()).getElementType(),
|
||||
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||
Value result =
|
||||
rewriter.create<arith::AddIOp>(loc, payloadArgs[0], minSIValue);
|
||||
|
@ -58,7 +58,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
|
|||
|
||||
static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
|
||||
PatternRewriter &rewriter) {
|
||||
auto valueTy = value.getType().cast<RankedTensorType>();
|
||||
auto valueTy = cast<RankedTensorType>(value.getType());
|
||||
auto inShape = valueTy.getShape();
|
||||
llvm::SmallVector<int64_t> outShape;
|
||||
llvm::SmallVector<Value> dynDims;
|
||||
|
@ -100,8 +100,8 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
||||
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
||||
RankedTensorType lhsType = cast<RankedTensorType>(lhs.getType());
|
||||
RankedTensorType rhsType = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -109,9 +109,9 @@ public:
|
|||
}
|
||||
|
||||
ValueTensorType lhsTorchType =
|
||||
op.getSelf().getType().cast<ValueTensorType>();
|
||||
cast<ValueTensorType>(op.getSelf().getType());
|
||||
ValueTensorType rhsTorchType =
|
||||
op.getMat2().getType().cast<ValueTensorType>();
|
||||
cast<ValueTensorType>(op.getMat2().getType());
|
||||
|
||||
Value lhsZeroPoint, rhsZeroPoint;
|
||||
getZeroPoint(op.getSelf(), lhsZeroPoint);
|
||||
|
@ -148,7 +148,7 @@ public:
|
|||
"mismatching contracting dimension for torch.aten.mm"));
|
||||
}
|
||||
|
||||
auto resultTy = op.getType().cast<ValueTensorType>();
|
||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
Type elementType = cast<TensorType>(newResultType).getElementType();
|
||||
|
@ -176,9 +176,9 @@ public:
|
|||
|
||||
// change uint8 quantization -> int8 quantization
|
||||
int64_t numBits =
|
||||
lhsType.getElementType().cast<mlir::IntegerType>().getWidth();
|
||||
cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
|
||||
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
|
||||
numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
|
||||
numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
|
||||
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
|
||||
|
||||
matmul =
|
||||
|
@ -229,9 +229,9 @@ public:
|
|||
MLIRContext *context = op.getContext();
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfRank =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
Type elementType =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
|
||||
Value c1 =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
|
||||
|
@ -299,8 +299,8 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
|
||||
return failure();
|
||||
}
|
||||
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
||||
auto rhsType = rhs.getType().cast<RankedTensorType>();
|
||||
auto lhsType = cast<RankedTensorType>(lhs.getType());
|
||||
auto rhsType = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
|
||||
auto rhsTorchType = cast<ValueTensorType>(op.getOther().getType());
|
||||
|
@ -348,9 +348,9 @@ public:
|
|||
|
||||
// change uint8 quantization -> int8 quantization
|
||||
int64_t numBits =
|
||||
lhsType.getElementType().cast<mlir::IntegerType>().getWidth();
|
||||
cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
|
||||
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
|
||||
numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
|
||||
numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
|
||||
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
|
||||
|
||||
// for quantized vec-vec, vec-mat, and mat-vec cases, lower to
|
||||
|
@ -726,8 +726,8 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
Value lhs = adaptor.getSelf();
|
||||
Value rhs = adaptor.getMat2();
|
||||
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
||||
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
||||
RankedTensorType lhsType = cast<RankedTensorType>(lhs.getType());
|
||||
RankedTensorType rhsType = cast<RankedTensorType>(rhs.getType());
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
Type resultElementType =
|
||||
cast<RankedTensorType>(newResultType).getElementType();
|
||||
|
@ -794,7 +794,7 @@ public:
|
|||
Value input = adaptor.getInput(); /* in form of N*C*H*W */
|
||||
Value weight = adaptor.getWeight(); /* in form of F*C*H*W */
|
||||
Value bias = adaptor.getBias();
|
||||
auto resultTy = op.getType().cast<ValueTensorType>();
|
||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||
|
||||
Value inputZp, weightZp;
|
||||
if (auto make = op.getInput()
|
||||
|
@ -826,7 +826,7 @@ public:
|
|||
}
|
||||
|
||||
if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
|
||||
auto biasDTy = bias.getType().cast<RankedTensorType>().getElementType();
|
||||
auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
|
||||
if (!biasDTy.isInteger(32)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "quantized result ty should be i32 accumulator");
|
||||
|
@ -838,15 +838,15 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only constant transposed supported");
|
||||
|
||||
auto inputDTy = input.getType().cast<RankedTensorType>().getElementType();
|
||||
auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType();
|
||||
auto inputDTy = cast<RankedTensorType>(input.getType()).getElementType();
|
||||
auto weightDTy = cast<RankedTensorType>(weight.getType()).getElementType();
|
||||
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
||||
|
||||
if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
|
||||
!isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
|
||||
!isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
|
||||
return op.emitError("unimplemented: non-fp not-int type");
|
||||
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
|
||||
size_t inRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||
size_t numSpatialDims = inRank - 2;
|
||||
if (numSpatialDims < 1 || numSpatialDims > 3)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1067,11 +1067,11 @@ public:
|
|||
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||
|
||||
} else {
|
||||
auto biasType = bias.getType().cast<RankedTensorType>();
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
|
||||
|
||||
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
|
||||
auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank();
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
// bias is used to initialize the channels - dimension 1 of output
|
||||
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
|
||||
|
@ -1228,9 +1228,9 @@ public:
|
|||
|
||||
// Special depthwise case
|
||||
auto inShape = makeShapeTorchCompatible(
|
||||
input.getType().cast<RankedTensorType>().getShape());
|
||||
cast<RankedTensorType>(input.getType()).getShape());
|
||||
auto weightShape = makeShapeTorchCompatible(
|
||||
weight.getType().cast<RankedTensorType>().getShape());
|
||||
cast<RankedTensorType>(weight.getType()).getShape());
|
||||
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
|
||||
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
|
||||
// Collapse weight shape
|
||||
|
@ -1264,7 +1264,7 @@ public:
|
|||
|
||||
// Grouped case, use the grouped conv linalg op
|
||||
auto expandGroups = [&](Value tensor, size_t dim) {
|
||||
auto inType = tensor.getType().cast<RankedTensorType>();
|
||||
auto inType = cast<RankedTensorType>(tensor.getType());
|
||||
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
||||
|
||||
SmallVector<int64_t> outShape;
|
||||
|
@ -1297,7 +1297,7 @@ public:
|
|||
|
||||
// expand F,C,H,W -> G,F/G,C,H,W
|
||||
auto expandWeight = [&](Value tensor) {
|
||||
auto inType = tensor.getType().cast<RankedTensorType>();
|
||||
auto inType = cast<RankedTensorType>(tensor.getType());
|
||||
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
||||
|
||||
SmallVector<int64_t> outShape{
|
||||
|
|
|
@ -80,7 +80,7 @@ computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
|
|||
SmallVectorImpl<int64_t> &dilationInts,
|
||||
SmallVectorImpl<Value> &kernelSizeIntValues,
|
||||
SmallVectorImpl<Value> &outTensorShape, Value initValue) {
|
||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||
Location loc = op->getLoc();
|
||||
|
||||
Value N = getDimOp(rewriter, loc, self, 0);
|
||||
|
@ -116,7 +116,7 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter,
|
|||
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
|
||||
SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
|
||||
|
||||
unsigned selfRank = self.getType().cast<RankedTensorType>().getRank();
|
||||
unsigned selfRank = cast<RankedTensorType>(self.getType()).getRank();
|
||||
unsigned paddingIntsSize = paddingInts.size();
|
||||
|
||||
if (paddingIntsSize == 2 * (selfRank - 2)) {
|
||||
|
@ -153,7 +153,7 @@ static LogicalResult createPoolingOp(
|
|||
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
|
||||
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
|
||||
Location loc = op->getLoc();
|
||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||
if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
|
||||
return op->emitError("unimplemented: non-floating point type");
|
||||
|
||||
|
@ -214,7 +214,7 @@ private:
|
|||
bool ceilMode) const {
|
||||
SmallVector<Value, 5> outTensorShape;
|
||||
Value self = adaptor.getSelf();
|
||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||
|
@ -307,7 +307,7 @@ public:
|
|||
|
||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
Value self = adaptor.getSelf();
|
||||
int64_t selfRank = self.getType().cast<RankedTensorType>().getRank();
|
||||
int64_t selfRank = cast<RankedTensorType>(self.getType()).getRank();
|
||||
|
||||
if (selfRank != Dim + 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -326,7 +326,7 @@ public:
|
|||
strideInts, paddingInts)))
|
||||
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
||||
|
||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||
|
||||
if constexpr (Dim == 2) {
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
|
@ -389,7 +389,7 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
Value self = adaptor.getSelf();
|
||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
||||
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||
Type elementType = selfType.getElementType();
|
||||
RankedTensorType indicesRankedTensorType =
|
||||
getTypeConverter()
|
||||
|
@ -552,7 +552,7 @@ public:
|
|||
Value self = adaptor.getSelf();
|
||||
|
||||
Type inputElementType =
|
||||
self.getType().cast<RankedTensorType>().getElementType();
|
||||
cast<RankedTensorType>(self.getType()).getElementType();
|
||||
Type resultType = typeConverter->convertType(op.getType());
|
||||
Type resultElementType =
|
||||
cast<RankedTensorType>(resultType).getElementType();
|
||||
|
@ -592,8 +592,7 @@ public:
|
|||
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
||||
Value kHtimeskW = rewriter.create<arith::MulIOp>(
|
||||
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
|
||||
divisor =
|
||||
op.getDivisorOverride().getType().template isa<Torch::NoneType>()
|
||||
divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
|
||||
? kHtimeskW
|
||||
: adaptor.getDivisorOverride();
|
||||
} else {
|
||||
|
@ -901,7 +900,7 @@ public:
|
|||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
const Type elementType = inputType.getElementType();
|
||||
|
||||
// get rank of input (same as rank of output)
|
||||
|
|
|
@ -127,7 +127,7 @@ public:
|
|||
Value from = adaptor.getFrom();
|
||||
Value to = adaptor.getTo();
|
||||
Value generator = adaptor.getGenerator();
|
||||
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(self.getType());
|
||||
Type elemTy = resultType.getElementType();
|
||||
Type f64Ty = rewriter.getF64Type();
|
||||
|
||||
|
|
|
@ -66,8 +66,7 @@ public:
|
|||
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
|
||||
auto idxResultType =
|
||||
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
|
||||
RankedTensorType inputType =
|
||||
input.getType().template cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
Type idxElementType =
|
||||
getElementTypeOrSelf(typec->convertType(idxResultType));
|
||||
if (!isa<IntegerType>(idxElementType))
|
||||
|
@ -472,7 +471,7 @@ private:
|
|||
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
||||
typename T::Adaptor adaptor(operands);
|
||||
opInfo.tensorOperand = adaptor.getSelf();
|
||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
||||
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -480,8 +479,7 @@ private:
|
|||
|
||||
SmallVector<int64_t> dimList;
|
||||
int64_t dim;
|
||||
bool isNoneOrEmptyDimList =
|
||||
op.getDim().getType().template isa<Torch::NoneType>();
|
||||
bool isNoneOrEmptyDimList = isa<Torch::NoneType>(op.getDim().getType());
|
||||
if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
|
||||
// Fix negative dimensions, if any, before adding to the list.
|
||||
for (int64_t dim : dimList) {
|
||||
|
@ -522,7 +520,7 @@ private:
|
|||
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
|
||||
AtenNormScalarOp>(op)) {
|
||||
opInfo.tensorOperand = operands[0];
|
||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
||||
|
||||
// `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
|
||||
// `AtenMinOp` each reduce along all the dimensions of the input tensor.
|
||||
|
|
|
@ -42,7 +42,7 @@ public:
|
|||
return failure();
|
||||
Location loc = op->getLoc();
|
||||
Value self = adaptor.getSelf();
|
||||
auto type = self.getType().cast<RankedTensorType>();
|
||||
auto type = cast<RankedTensorType>(self.getType());
|
||||
int64_t rank = type.getRank();
|
||||
|
||||
auto primList = op.getPad().getDefiningOp<Torch::PrimListConstructOp>();
|
||||
|
@ -105,7 +105,7 @@ public:
|
|||
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
|
||||
|
||||
Type padType = tensor::PadOp::inferResultType(
|
||||
self.getType().cast<RankedTensorType>(), staticLow, staticHigh);
|
||||
cast<RankedTensorType>(self.getType()), staticLow, staticHigh);
|
||||
Value paddedInput = rewriter.create<tensor::PadOp>(
|
||||
loc, padType, self, lowPad, highPad, castedValue);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
|
||||
|
@ -354,7 +354,7 @@ public:
|
|||
|
||||
// The pin_memory should be either `False` or `none`.
|
||||
bool pinMemory;
|
||||
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
|
||||
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -376,7 +376,7 @@ public:
|
|||
auto resultType = typeConverter->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
Type resultElementType;
|
||||
if (op.getDtype().getType().template isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||
resultElementType = resultType.getElementType();
|
||||
} else {
|
||||
int64_t dtypeInt;
|
||||
|
@ -423,7 +423,7 @@ public:
|
|||
|
||||
// The pin_memory should be either `False` or `none`.
|
||||
bool pinMemory;
|
||||
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
|
||||
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -480,7 +480,7 @@ public:
|
|||
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||
|
||||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
Type resultElementType;
|
||||
if (op.getDtype().getType().isa<Torch::NoneType>()) {
|
||||
resultElementType = getDefaultDtypeForTorchScalar(
|
||||
|
|
|
@ -38,7 +38,7 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
Value self = adaptor.getSelf();
|
||||
Value dim = adaptor.getDim();
|
||||
auto type = self.getType().cast<RankedTensorType>();
|
||||
auto type = cast<RankedTensorType>(self.getType());
|
||||
Value inputRank = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI64IntegerAttr(type.getRank()));
|
||||
Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
|
||||
|
@ -86,8 +86,7 @@ public:
|
|||
Value input = adaptor.getA();
|
||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||
int64_t inputRank = inputSizes.size();
|
||||
Type inputDtype =
|
||||
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
||||
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
|
||||
|
||||
// The `input` tensor must contain exactly one element, i.e., either the
|
||||
// `input` is a zero rank tensor or all the dimensions of the `input` tensor
|
||||
|
|
|
@ -34,7 +34,7 @@ using namespace mlir::torch::Torch;
|
|||
|
||||
// Check if a ranked-tensor has the specified element type.
|
||||
template <typename elementType> static bool hasElementType(Value tensor) {
|
||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
||||
auto tensorType = cast<RankedTensorType>(tensor.getType());
|
||||
Type tensorElementType = tensorType.getElementType();
|
||||
return isa<elementType>(tensorElementType);
|
||||
}
|
||||
|
@ -173,8 +173,7 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
Type elementalType =
|
||||
op.getSelf().getType().template cast<BaseTensorType>().getDtype();
|
||||
Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
|
||||
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
|
||||
return createLessThan(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
@ -200,7 +199,7 @@ template <arith::CmpIPredicate predicate>
|
|||
static LogicalResult
|
||||
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
|
||||
Operation *op, ArrayRef<Value> operands, Value &result) {
|
||||
auto inputType = operands[0].getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(operands[0].getType());
|
||||
uint64_t inputRank = inputType.getRank();
|
||||
|
||||
// Use the indices of the two innermost dimensions.
|
||||
|
@ -405,7 +404,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return nullptr;
|
||||
}
|
||||
Type resultElementType =
|
||||
bitwiseAndScalar.getType().cast<BaseTensorType>().getDtype();
|
||||
cast<BaseTensorType>(bitwiseAndScalar.getType()).getDtype();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
|
@ -537,7 +536,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
if (auto relu = dyn_cast<AtenReluOp>(op)) {
|
||||
Value zeroPoint = getZeroPoint(relu.getSelf());
|
||||
Value arg = payloadArgs[0];
|
||||
auto intType = arg.getType().dyn_cast<mlir::IntegerType>();
|
||||
auto intType = dyn_cast<mlir::IntegerType>(arg.getType());
|
||||
if (zeroPoint && !intType) {
|
||||
relu.emitError("unimplemented: non-integer quantized Relu.");
|
||||
return nullptr;
|
||||
|
@ -739,9 +738,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
||||
AtenAddTensorOp::Adaptor adaptor(operands);
|
||||
Type resultElementType = add.getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = converter->convertType(add.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type resultElementType = cast<BaseTensorType>(add.getType()).getDtype();
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(add.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
|
@ -762,10 +760,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
|
||||
AtenSubTensorOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(sub.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(sub.getType()))
|
||||
.getElementType();
|
||||
Type resultElementType = sub.getType().cast<BaseTensorType>().getDtype();
|
||||
Type resultElementType = cast<BaseTensorType>(sub.getType()).getDtype();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
|
@ -785,8 +782,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
}
|
||||
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
|
||||
Type dtype = converter->convertType(subScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(subScalar.getType()))
|
||||
.getElementType();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
|
@ -805,11 +802,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return nullptr;
|
||||
}
|
||||
if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
|
||||
Type dtype = converter->convertType(addScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(addScalar.getType()))
|
||||
.getElementType();
|
||||
Type resultElementType =
|
||||
addScalar.getType().cast<BaseTensorType>().getDtype();
|
||||
cast<BaseTensorType>(addScalar.getType()).getDtype();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
|
@ -832,8 +829,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
||||
AtenMulTensorOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(mul.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(mul.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
|
@ -846,8 +842,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
}
|
||||
if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
|
||||
Type dtype = converter->convertType(atan2.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(atan2.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::FloatType>(dtype)) {
|
||||
atan2.emitError("Atan2 requires floating point result type");
|
||||
|
@ -883,8 +878,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
||||
AtenDivTensorOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(div.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(div.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
|
@ -907,7 +901,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
operands);
|
||||
}
|
||||
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
||||
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
|
||||
Type dtype = cast<ValueTensorType>(pow.getType()).getDtype();
|
||||
if (!isa<mlir::FloatType>(dtype)) {
|
||||
pow.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
|
@ -925,14 +919,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
pow.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Type dtype = pow.getSelf().getType().cast<ValueTensorType>().getDtype();
|
||||
Type dtype = cast<ValueTensorType>(pow.getSelf().getType()).getDtype();
|
||||
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
|
||||
}
|
||||
|
||||
if (auto pow = dyn_cast<AtenPowTensorTensorOp>(op)) {
|
||||
Type dtype = converter->convertType(pow.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(pow.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::FloatType>(dtype)) {
|
||||
pow.emitError("unimplemented: non-floating point dtype");
|
||||
|
@ -944,8 +937,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto imag = dyn_cast<AtenImagOp>(op)) {
|
||||
Type dtype = converter->convertType(imag.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(imag.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::FloatType>(dtype)) {
|
||||
imag.emitError("unimplemented: non-floating point dtype");
|
||||
|
@ -956,8 +948,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto real = dyn_cast<AtenRealOp>(op)) {
|
||||
Type dtype = converter->convertType(real.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(real.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::FloatType>(dtype)) {
|
||||
real.emitError("unimplemented: non-floating point dtype");
|
||||
|
@ -968,7 +959,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
|
||||
Type dtype = gtScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = cast<BaseTensorType>(gtScalar.getSelf().getType()).getDtype();
|
||||
|
||||
// TODO: `gtTensor` and `gtScalar` share similar code and can be called from
|
||||
// one static function.
|
||||
|
@ -998,7 +989,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
|
||||
Type dtype = geScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = cast<BaseTensorType>(geScalar.getSelf().getType()).getDtype();
|
||||
|
||||
// TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that
|
||||
// can be refactored.
|
||||
|
@ -1028,7 +1019,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
|
||||
Type dtype = eqScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = cast<BaseTensorType>(eqScalar.getSelf().getType()).getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
|
@ -1044,7 +1035,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
|
||||
Type dtype = neScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = cast<BaseTensorType>(neScalar.getSelf().getType()).getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
|
@ -1060,7 +1051,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
||||
Type dtype = ltScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
|
@ -1088,7 +1079,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
|
||||
Type dtype = leScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = cast<BaseTensorType>(leScalar.getSelf().getType()).getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
|
@ -1116,8 +1107,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
|
||||
Type dtype = converter->convertType(whereSelf.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(whereSelf.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
|
||||
|
@ -1141,7 +1132,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::AddFOp>(loc, start, weightedDelta);
|
||||
}
|
||||
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
|
||||
Type dtype = minimum.getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype();
|
||||
Type elemTy = converter->convertType(minimum.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
@ -1151,7 +1142,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
|
||||
}
|
||||
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
|
||||
Type dtype = maximum.getType().cast<BaseTensorType>().getDtype();
|
||||
Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype();
|
||||
Type elemTy = converter->convertType(maximum.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
@ -1170,15 +1161,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
Type dtype = converter->convertType(clamp.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(clamp.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) {
|
||||
clamp.emitError("unimplement type for clamp");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
|
||||
Type dstOriginalDtype = cast<BaseTensorType>(clamp.getType()).getDtype();
|
||||
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
|
||||
if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) {
|
||||
isUnsigned = intTy.isUnsigned();
|
||||
|
@ -1219,8 +1209,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
clampTensor.emitError("unimplemented: runtime optional type");
|
||||
return nullptr;
|
||||
}
|
||||
Type dtype = converter->convertType(clampTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(clampTensor.getType()))
|
||||
.getElementType();
|
||||
bool isMinNone = true;
|
||||
auto result = payloadArgs[0];
|
||||
|
@ -1263,8 +1253,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return result;
|
||||
}
|
||||
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
|
||||
Type dtype = converter->convertType(rsub.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(rsub.getType()))
|
||||
.getElementType();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
|
@ -1283,8 +1272,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return nullptr;
|
||||
}
|
||||
if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
|
||||
Type dtype = converter->convertType(mulScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(mulScalar.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
|
@ -1297,8 +1286,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
||||
Value input = payloadArgs[0];
|
||||
Type dtype = converter->convertType(atenToDtype.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(atenToDtype.getType()))
|
||||
.getElementType();
|
||||
Type resultElementType;
|
||||
int64_t dtypeInt;
|
||||
|
@ -1320,8 +1309,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return result;
|
||||
}
|
||||
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
|
||||
Type dtype = converter->convertType(divScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(divScalar.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::FloatType>(dtype)) {
|
||||
divScalar.emitError("unimplemented: non-floating point dtype");
|
||||
|
@ -1395,8 +1384,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return result;
|
||||
}
|
||||
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
|
||||
Type dtype = converter->convertType(reciprocal.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(reciprocal.getType()))
|
||||
.getElementType();
|
||||
Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Type elementType = arg.getType();
|
||||
|
@ -1416,8 +1405,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
// The approach used here is as follows:
|
||||
// result = self <= threshold ? value : self
|
||||
AtenThresholdOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(thresholdOp.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(thresholdOp.getType()))
|
||||
.getElementType();
|
||||
|
||||
Value self = payloadArgs[0];
|
||||
|
@ -1438,8 +1427,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
// The approach used here is as follows:
|
||||
// result = self <= threshold ? 0 : grad
|
||||
AtenThresholdBackwardOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(thresholdBackward.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(
|
||||
converter->convertType(thresholdBackward.getType()))
|
||||
.getElementType();
|
||||
|
||||
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
|
@ -1459,15 +1448,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto fillScalar = dyn_cast<AtenFillScalarOp>(op)) {
|
||||
AtenFillScalarOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(fillScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(fillScalar.getType()))
|
||||
.getElementType();
|
||||
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
|
||||
}
|
||||
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
|
||||
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(maskedFillTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(
|
||||
converter->convertType(maskedFillTensor.getType()))
|
||||
.getElementType();
|
||||
|
||||
Value input = payloadArgs[0];
|
||||
|
@ -1477,8 +1466,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto fillTensor = dyn_cast<AtenFillTensorOp>(op)) {
|
||||
AtenFillTensorOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(fillTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(fillTensor.getType()))
|
||||
.getElementType();
|
||||
return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
}
|
||||
|
@ -1519,7 +1508,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
auto value = payloadArgs[0];
|
||||
auto valueTy = value.getType();
|
||||
auto qtensor = op->getOperand(0);
|
||||
auto qtensorTy = qtensor.getType().cast<ValueTensorType>().getDtype();
|
||||
auto qtensorTy = cast<ValueTensorType>(qtensor.getType()).getDtype();
|
||||
|
||||
Value zp, scale;
|
||||
if (auto makeQTensor =
|
||||
|
@ -1744,8 +1733,8 @@ public:
|
|||
Value ignoreIndex = adaptor.getIgnoreIndex();
|
||||
Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
|
||||
|
||||
unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
|
||||
unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
|
||||
unsigned inputRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||
unsigned targetRank = cast<RankedTensorType>(target.getType()).getRank();
|
||||
|
||||
// TODO: Add support for k-dim loss.
|
||||
if (inputRank > 2) {
|
||||
|
@ -1931,11 +1920,11 @@ public:
|
|||
failed(checkNotNone(rewriter, op, runningVar)))
|
||||
return failure();
|
||||
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto weightType = weight.getType().cast<RankedTensorType>();
|
||||
auto biasType = bias.getType().cast<RankedTensorType>();
|
||||
auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
|
||||
auto runningVarType = runningVar.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto weightType = cast<RankedTensorType>(weight.getType());
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
auto runningMeanType = cast<RankedTensorType>(runningMean.getType());
|
||||
auto runningVarType = cast<RankedTensorType>(runningVar.getType());
|
||||
|
||||
auto inputRank = inputType.getRank();
|
||||
if (inputRank < 2)
|
||||
|
@ -2032,9 +2021,9 @@ public:
|
|||
Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
|
||||
Value totalWeight = adaptor.getTotalWeight();
|
||||
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
int inputRank = inputType.getRank();
|
||||
auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
|
||||
auto gradOutputType = cast<RankedTensorType>(gradOutput.getType());
|
||||
Type resultElementType = gradOutputType.getElementType();
|
||||
|
||||
int64_t reduction;
|
||||
|
@ -2059,7 +2048,7 @@ public:
|
|||
createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
|
||||
|
||||
auto getAffineMapForSingleElementTensor = [&](Value tensor) {
|
||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
||||
auto tensorType = cast<RankedTensorType>(tensor.getType());
|
||||
SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
|
||||
rewriter.getAffineConstantExpr(0));
|
||||
return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
|
||||
|
@ -2188,12 +2177,12 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>();
|
||||
auto aRankedTensorType = cast<RankedTensorType>(adaptor.getA().getType());
|
||||
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto resultRankedTensorType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
// The dimension being split must be statically known.
|
||||
|
||||
|
@ -2233,11 +2222,11 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>();
|
||||
auto aRankedTensorType = cast<RankedTensorType>(adaptor.getA().getType());
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto resultRankedTensorType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
// Collapse range must be statically known.
|
||||
int64_t startInt;
|
||||
|
@ -2328,7 +2317,7 @@ public:
|
|||
return failure();
|
||||
}
|
||||
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputElementType = inputType.getElementType();
|
||||
|
||||
if (!isa<mlir::FloatType>(inputElementType)) {
|
||||
|
@ -2433,8 +2422,8 @@ public:
|
|||
return failure();
|
||||
}
|
||||
|
||||
auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype();
|
||||
auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype();
|
||||
auto operandDTy = cast<ValueTensorType>(operand.getType()).getDtype();
|
||||
auto zeropointDTy = cast<ValueTensorType>(zeropoint.getType()).getDtype();
|
||||
operand = converter->materializeTargetConversion(
|
||||
rewriter, loc, converter->convertType(operand.getType()), operand);
|
||||
scale = converter->materializeTargetConversion(
|
||||
|
@ -2537,7 +2526,7 @@ public:
|
|||
Value twoFloat = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getFloatAttr(floatType, 2.0));
|
||||
Value input = adaptor.getInput();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputShape = inputType.getShape();
|
||||
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
|
||||
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
|
||||
|
@ -2558,7 +2547,7 @@ public:
|
|||
Value innerDim1e =
|
||||
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
|
||||
Value grid = adaptor.getGrid();
|
||||
auto gridType = grid.getType().cast<RankedTensorType>();
|
||||
auto gridType = cast<RankedTensorType>(grid.getType());
|
||||
auto gridShape = gridType.getShape();
|
||||
auto gridRank = gridType.getRank();
|
||||
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
|
||||
|
|
|
@ -37,9 +37,8 @@ Value torch_to_linalg::getPaddedTensor(
|
|||
SmallVectorImpl<int64_t> &lowPaddingInts,
|
||||
SmallVectorImpl<int64_t> &highPaddingInts, Value pad) {
|
||||
Location loc = op->getLoc();
|
||||
Type rankedTensorType =
|
||||
tensor::PadOp::inferResultType(input.getType().cast<RankedTensorType>(),
|
||||
lowPaddingInts, highPaddingInts);
|
||||
Type rankedTensorType = tensor::PadOp::inferResultType(
|
||||
cast<RankedTensorType>(input.getType()), lowPaddingInts, highPaddingInts);
|
||||
SmallVector<OpFoldResult> lowPaddings =
|
||||
getIndexIntsAsOpFoldResult(b, lowPaddingInts);
|
||||
SmallVector<OpFoldResult> highPaddings =
|
||||
|
@ -61,7 +60,7 @@ Value torch_to_linalg::getZeroPaddedTensor(
|
|||
Location loc = op->getLoc();
|
||||
Value c0 = b.create<arith::ConstantOp>(
|
||||
loc,
|
||||
b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType()));
|
||||
b.getZeroAttr(cast<RankedTensorType>(input.getType()).getElementType()));
|
||||
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
|
||||
}
|
||||
|
||||
|
@ -73,7 +72,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
|
|||
int unpaddedDims, Value pad) {
|
||||
assert(input.getType().isa<RankedTensorType>() &&
|
||||
"input must be RankedTensorType");
|
||||
unsigned int inRank = input.getType().cast<RankedTensorType>().getRank();
|
||||
unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||
Location loc = op->getLoc();
|
||||
|
||||
SmallVector<Value> inputDims = getTensorSizes(b, loc, input);
|
||||
|
@ -86,7 +85,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
|
|||
pad < paddingIncludingUnchanged.end(); pad++)
|
||||
*pad = castIntToIndex(b, loc, *pad);
|
||||
|
||||
Type elementType = input.getType().cast<RankedTensorType>().getElementType();
|
||||
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
|
||||
// TODO: audit possibility of sparsity on this tensor
|
||||
Type inputType =
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
|
||||
|
@ -158,7 +157,7 @@ Value torch_to_linalg::getOutputDimForConvTransposeOps(
|
|||
Value torch_to_linalg::createReductionLinalgGeneric(
|
||||
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
||||
|
||||
// Get the result shape by obtaining the size of each
|
||||
// dimension in the input tensor that is not getting reduced.
|
||||
|
@ -237,7 +236,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
|
|||
SmallVector<int64_t> operandRanks;
|
||||
operandRanks.resize(tensorOperands.size());
|
||||
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
|
||||
return tensor.getType().dyn_cast<RankedTensorType>().getRank();
|
||||
return dyn_cast<RankedTensorType>(tensor.getType()).getRank();
|
||||
});
|
||||
|
||||
auto resultRankIt =
|
||||
|
@ -253,7 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
|
|||
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);
|
||||
for (Value tensorOperand : tensorOperands) {
|
||||
SmallVector<AffineExpr> exprs;
|
||||
auto type = tensorOperand.getType().cast<RankedTensorType>();
|
||||
auto type = cast<RankedTensorType>(tensorOperand.getType());
|
||||
for (auto size :
|
||||
llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) {
|
||||
// If the size is statically known to be 1, we don't want any
|
||||
|
@ -327,7 +326,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
|
|||
Operation *op, PatternRewriter &rewriter, Value input,
|
||||
SmallVector<Value> broadcastToShape, RankedTensorType broadcastType,
|
||||
Value &result, SmallVector<bool> useBroadcastToShape) {
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
int64_t inputRank = inputType.getRank();
|
||||
int64_t outputRank = broadcastToShape.size();
|
||||
ArrayRef<int64_t> outputShape = broadcastType.getShape();
|
||||
|
@ -525,7 +524,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
|
|||
|
||||
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
|
||||
Value tensor) {
|
||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
||||
auto tensorType = cast<RankedTensorType>(tensor.getType());
|
||||
auto rank = tensorType.getRank();
|
||||
SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
|
||||
return b.create<tensor::CastOp>(
|
||||
|
|
|
@ -66,8 +66,8 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
|
|||
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
||||
mlir::Value &self, mlir::Value &other,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto selfTy = self.getType().template dyn_cast<RankedTensorType>();
|
||||
auto otherTy = other.getType().template dyn_cast<RankedTensorType>();
|
||||
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||
auto otherTy = dyn_cast<RankedTensorType>(other.getType());
|
||||
auto selfRank = selfTy.getRank();
|
||||
auto otherRank = otherTy.getRank();
|
||||
if (selfRank == 0 || otherRank == 0)
|
||||
|
@ -171,7 +171,7 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfType = self.getType().cast<TensorType>();
|
||||
auto selfType = cast<TensorType>(self.getType());
|
||||
if (!selfType) {
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
}
|
||||
|
@ -197,12 +197,12 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
if (selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
|
@ -229,14 +229,14 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
if (resultTy.getElementType().template isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(resultTy.getElementType())) {
|
||||
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
|
||||
return success();
|
||||
|
@ -304,8 +304,7 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto inputType =
|
||||
adaptor.getA().getType().template dyn_cast<RankedTensorType>();
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
|
||||
if (!inputType)
|
||||
|
||||
op.emitError("only Tensor types supported in StableHLO");
|
||||
|
@ -313,8 +312,7 @@ public:
|
|||
Value input = adaptor.getA();
|
||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||
int64_t inputRank = inputSizes.size();
|
||||
Type inputDtype =
|
||||
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
||||
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
|
||||
|
||||
Value constantOne =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
|
@ -345,9 +343,9 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
auto lhsTy = cast<TensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getOther();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
auto rhsTy = cast<TensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("only Tensor types supported");
|
||||
|
@ -378,9 +376,9 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getOther();
|
||||
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType rhsType = dyn_cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
@ -433,9 +431,9 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getOther();
|
||||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
@ -527,8 +525,8 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
Value rhs = adaptor.getOther();
|
||||
RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
|
||||
RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
@ -616,8 +614,8 @@ public:
|
|||
Value lhs = adaptor.getSelf();
|
||||
Value rhs = adaptor.getOther();
|
||||
|
||||
RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
|
||||
RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("lhs must be a ranked tensor type");
|
||||
|
@ -659,11 +657,10 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
||||
}
|
||||
|
||||
auto inType = self.getType().cast<RankedTensorType>();
|
||||
auto inType = cast<RankedTensorType>(self.getType());
|
||||
auto inputRank = inType.getRank();
|
||||
auto outType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto outType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
|
||||
dim0 = toPositiveDim(dim0, inputRank);
|
||||
if (!isValidDim(dim0, inputRank)) {
|
||||
|
@ -691,7 +688,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
|
||||
return success();
|
||||
}
|
||||
|
@ -701,7 +698,7 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
|||
AtenSizeIntOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return op.emitError("only tensor types are currently supported");
|
||||
|
||||
|
@ -739,7 +736,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
|||
Value other = adaptor.getOther();
|
||||
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
// promote self and other types
|
||||
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
||||
other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
|
||||
|
@ -764,10 +761,9 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
AtenBroadcastToOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto outType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
auto outType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
|
||||
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
||||
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
|
||||
|
@ -831,10 +827,9 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
// Not a ranked tensor type
|
||||
auto inType = self.getType().dyn_cast<RankedTensorType>();
|
||||
auto outType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto inType = dyn_cast<RankedTensorType>(self.getType());
|
||||
auto outType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
if (!inType)
|
||||
return op.emitError("only ranked tensor types with static shapes are "
|
||||
"currently supported");
|
||||
|
@ -861,15 +856,14 @@ template <>
|
|||
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||
ValueTensorLiteralOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
|
||||
// Tensors with integer types need to be converted to signless integer
|
||||
// element type. All tensors with element types other than integer can reuse
|
||||
// existing elements attribute.
|
||||
// TODO: what about unsigned integer?
|
||||
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
|
||||
Type builtinTensorElemTy = resultType.getElementType();
|
||||
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
||||
|
||||
|
@ -892,9 +886,8 @@ template <>
|
|||
LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite(
|
||||
AtenTensorIntOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type outElementType = resultType.getElementType();
|
||||
Value innerValue = adaptor.getT();
|
||||
Value stablehloTensor =
|
||||
|
@ -910,10 +903,10 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
|||
AtenReciprocalOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
if (!inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
if (!isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||
return op.emitError("only floating-point datatype legalization supported "
|
||||
"for AtenReciprocalOp");
|
||||
}
|
||||
|
@ -929,9 +922,9 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getExponent();
|
||||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
@ -1002,9 +995,8 @@ template <>
|
|||
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
||||
PrimNumToTensorScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
RankedTensorType outputType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType outputType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
auto outputElemType = outputType.getElementType();
|
||||
Value stablehloTensor = hlo::scalarToStablehloTensor(
|
||||
rewriter, op, adaptor.getA(), outputElemType);
|
||||
|
@ -1018,8 +1010,7 @@ LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
|
|||
AtenScalarImplicitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Type inputDtype =
|
||||
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
||||
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
|
||||
Type resultType =
|
||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||
auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
|
||||
|
@ -1037,7 +1028,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return op.emitError("only tensor types are currently supported");
|
||||
|
||||
|
@ -1055,7 +1046,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
AtenReluOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
|
||||
if (!isa<mlir::FloatType>(lhsElemTy)) {
|
||||
|
@ -1080,7 +1071,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputTy) {
|
||||
return op.emitError("only ranked tensor type is supported.");
|
||||
}
|
||||
|
@ -1103,11 +1094,11 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
|||
AtenLog2Op op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputTy) {
|
||||
return op.emitError("only ranked tensor type is supported.");
|
||||
}
|
||||
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
|
||||
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
||||
|
@ -1124,12 +1115,12 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
|
|||
AtenLog10Op op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputTy) {
|
||||
return op.emitError("only ranked tensor type is supported.");
|
||||
}
|
||||
|
||||
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
|
||||
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
||||
|
@ -1146,8 +1137,8 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
|||
AtenErfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().cast<TensorType>();
|
||||
if (!inputType.getElementType().isa<mlir::FloatType>()) {
|
||||
auto inputType = cast<TensorType>(input.getType());
|
||||
if (!isa<mlir::FloatType>(inputType.getElementType())) {
|
||||
return rewriter.notifyMatchFailure(op, "only float tensor is supported");
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<chlo::ErfOp>(
|
||||
|
@ -1161,7 +1152,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
AtenBatchNormOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getInput();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
Value weight = adaptor.getWeight();
|
||||
Value bias = adaptor.getBias();
|
||||
Value runningMean = adaptor.getRunningMean();
|
||||
|
@ -1174,10 +1165,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
// all of NC, NCL, NCHW, NCDHW's feature index is 1.
|
||||
int64_t feature_index = 1;
|
||||
|
||||
if (!inputTy.getElementType().template isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||
return op.emitError("only input tensor of float type is supported");
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>();
|
||||
auto inputElemTy = cast<mlir::FloatType>(inputTy.getElementType());
|
||||
|
||||
Value channelDim =
|
||||
rewriter.create<tensor::DimOp>(op->getLoc(), input, feature_index);
|
||||
|
@ -1220,20 +1211,20 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
inputTy.getElementType()));
|
||||
}
|
||||
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto biasTy = bias.getType().cast<RankedTensorType>();
|
||||
auto runningMeanTy = runningMean.getType().cast<RankedTensorType>();
|
||||
auto runningVarTy = runningVar.getType().cast<RankedTensorType>();
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
auto biasTy = cast<RankedTensorType>(bias.getType());
|
||||
auto runningMeanTy = cast<RankedTensorType>(runningMean.getType());
|
||||
auto runningVarTy = cast<RankedTensorType>(runningVar.getType());
|
||||
|
||||
if (weightTy.getRank() != 1 || biasTy.getRank() != 1 ||
|
||||
runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expect weight, bias, running_mean and running_var to be rank 1");
|
||||
}
|
||||
if (!weightTy.getElementType().template isa<mlir::FloatType>() ||
|
||||
!biasTy.getElementType().template isa<mlir::FloatType>() ||
|
||||
!runningMeanTy.getElementType().template isa<mlir::FloatType>() ||
|
||||
!runningVarTy.getElementType().template isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(weightTy.getElementType()) ||
|
||||
!isa<mlir::FloatType>(biasTy.getElementType()) ||
|
||||
!isa<mlir::FloatType>(runningMeanTy.getElementType()) ||
|
||||
!isa<mlir::FloatType>(runningVarTy.getElementType())) {
|
||||
return op.emitError("only float weight/bias/runningMean/runningVar tensor "
|
||||
"of float type is supported");
|
||||
}
|
||||
|
@ -1261,8 +1252,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
// supported mixed types, like input type is fp16 and weight type is fp32.
|
||||
if (inputTy.getElementType() != weightTy.getElementType()) {
|
||||
RankedTensorType convertedType = inputTy;
|
||||
if (weightTy.getElementType().cast<FloatType>().getWidth() >
|
||||
inputTy.getElementType().cast<FloatType>().getWidth()) {
|
||||
if (cast<FloatType>(weightTy.getElementType()).getWidth() >
|
||||
cast<FloatType>(inputTy.getElementType()).getWidth()) {
|
||||
convertedType = RankedTensorType::get(inputTy.getShape(),
|
||||
weightTy.getElementType());
|
||||
}
|
||||
|
@ -1302,8 +1293,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
// supported mixed types, like input type is fp16 and weight type is fp32.
|
||||
if (inputTy.getElementType() != weightTy.getElementType()) {
|
||||
RankedTensorType convertedType = inputTy;
|
||||
if (weightTy.getElementType().cast<FloatType>().getWidth() >
|
||||
inputTy.getElementType().cast<FloatType>().getWidth()) {
|
||||
if (cast<FloatType>(weightTy.getElementType()).getWidth() >
|
||||
cast<FloatType>(inputTy.getElementType()).getWidth()) {
|
||||
convertedType = RankedTensorType::get(inputTy.getShape(),
|
||||
weightTy.getElementType());
|
||||
}
|
||||
|
@ -1340,7 +1331,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
AtenNativeLayerNormOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getInput();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputShape = inputTy.getShape();
|
||||
auto inputRank = inputTy.getRank();
|
||||
Value weight = adaptor.getWeight();
|
||||
|
@ -1365,12 +1356,12 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
failed(checkNotNone(rewriter, op, bias))) {
|
||||
return op->emitError("none weight or bias is unsupported");
|
||||
}
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto biasTy = bias.getType().cast<RankedTensorType>();
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
auto biasTy = cast<RankedTensorType>(bias.getType());
|
||||
|
||||
if (!inputTy.getElementType().isa<mlir::FloatType>() ||
|
||||
!biasTy.getElementType().isa<mlir::FloatType>() ||
|
||||
!weightTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(inputTy.getElementType()) ||
|
||||
!isa<mlir::FloatType>(biasTy.getElementType()) ||
|
||||
!isa<mlir::FloatType>(weightTy.getElementType())) {
|
||||
return op->emitError("currently only float data type are supported");
|
||||
}
|
||||
int64_t normalizedShapeRank = normalizedShape.size();
|
||||
|
@ -1423,7 +1414,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
SmallVector<APFloat> oneConstVec(
|
||||
numFeatureDimSize,
|
||||
APFloat(
|
||||
inputTy.getElementType().cast<mlir::FloatType>().getFloatSemantics(),
|
||||
cast<mlir::FloatType>(inputTy.getElementType()).getFloatSemantics(),
|
||||
1));
|
||||
auto oneOrZeroConstType =
|
||||
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
|
||||
|
@ -1443,9 +1434,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
|
||||
// Reshape back
|
||||
auto outputTy =
|
||||
getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
|
||||
auto outputMeanOrVarTy =
|
||||
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||
|
||||
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
||||
|
@ -1482,7 +1473,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
AtenCatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -1516,7 +1507,7 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
|||
AtenNumelOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
|
||||
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||
size_t rank = selfTy.getRank();
|
||||
|
||||
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
||||
|
@ -1544,7 +1535,7 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
AtenClampOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemType = inputType.getElementType();
|
||||
Value minValue = adaptor.getMin();
|
||||
Value maxValue = adaptor.getMax();
|
||||
|
@ -1716,7 +1707,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto outType =
|
||||
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
||||
if (!outType) {
|
||||
return op.emitError("only tensor type is supported");
|
||||
}
|
||||
|
@ -1764,15 +1755,15 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
|
|||
AtenPowTensorTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
auto lhsTy = cast<TensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getExponent();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
auto rhsTy = cast<TensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("only Tensor types supported");
|
||||
|
||||
auto outTy =
|
||||
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
||||
|
@ -1790,12 +1781,12 @@ LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
|
|||
Value generator = adaptor.getGenerator();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
if (!generator.getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(generator.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
||||
auto elements = self.getType().cast<RankedTensorType>().getShape();
|
||||
auto elements = cast<RankedTensorType>(self.getType()).getShape();
|
||||
if (llvm::any_of(elements,
|
||||
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
|
||||
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
|
||||
|
@ -1824,14 +1815,14 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
|||
|
||||
// The pin_memory should be either `False` or `none`.
|
||||
bool pinMemory;
|
||||
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
|
||||
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: pin_memory must be either None or false");
|
||||
|
||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
|
||||
int64_t memoryFormat;
|
||||
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1844,7 +1835,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
|||
"memory_format is supported");
|
||||
}
|
||||
|
||||
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getDevice().getType())) {
|
||||
std::string device;
|
||||
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1853,7 +1844,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
|||
|
||||
// TODO: Add support for non-strided layout.
|
||||
// torch.layout is by default strided i.e. 0.
|
||||
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||
int64_t tensorLayout;
|
||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1876,9 +1867,9 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
|||
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||
|
||||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
Type resultElementType;
|
||||
if (op.getDtype().getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||
resultElementType = resultType.getElementType();
|
||||
} else {
|
||||
int64_t dtypeInt;
|
||||
|
@ -1931,7 +1922,7 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
|||
AtenFillScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto dtype = outType.getElementType();
|
||||
Value scalarTensor =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype);
|
||||
|
@ -1951,7 +1942,7 @@ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) {
|
||||
|
|
|
@ -64,7 +64,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
|||
loc, rewriter.getIntegerAttr(intType, 1));
|
||||
|
||||
// sliceSizes
|
||||
auto inputRankTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputRankTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
auto inputRank = inputRankTy.getRank();
|
||||
SmallVector<Value, 4> sliceSizes;
|
||||
sliceSizes.reserve(inputRank);
|
||||
|
@ -85,7 +85,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
|||
for (int64_t r = 0; r < axis; ++r) {
|
||||
offsetDims.push_back(r);
|
||||
}
|
||||
auto indicesRankTy = indices.getType().dyn_cast<RankedTensorType>();
|
||||
auto indicesRankTy = dyn_cast<RankedTensorType>(indices.getType());
|
||||
auto indicesRank = indicesRankTy.getRank();
|
||||
for (int64_t r = axis + 1; r < inputRank; ++r) {
|
||||
offsetDims.push_back(r + indicesRank - 1);
|
||||
|
@ -132,8 +132,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|||
SmallVector<Value> &strides) {
|
||||
Location loc = op.getLoc();
|
||||
auto input = adaptor.getSelf();
|
||||
RankedTensorType inputType =
|
||||
input.getType().template cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
|
||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
|
@ -161,7 +160,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|||
|
||||
int64_t step;
|
||||
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||||
if (!op.getStep().getType().template isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getStep().getType()))
|
||||
return op->emitError("unimplemented: step is not constant");
|
||||
step = 1;
|
||||
}
|
||||
|
@ -225,7 +224,7 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
|
|||
// concat index tensor into to indices tensor for concat
|
||||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||
auto indexTensor = indexTensors[i];
|
||||
auto indexTensorType = indexTensor.getType().cast<RankedTensorType>();
|
||||
auto indexTensorType = cast<RankedTensorType>(indexTensor.getType());
|
||||
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
|
||||
if (size == kUnknownSize)
|
||||
return failure();
|
||||
|
@ -249,7 +248,7 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
|
|||
|
||||
SmallVector<Value> broadcastedIndices;
|
||||
Type indexElemTy =
|
||||
indexTensors[0].getType().cast<RankedTensorType>().getElementType();
|
||||
cast<RankedTensorType>(indexTensors[0].getType()).getElementType();
|
||||
RankedTensorType bcastIndexType =
|
||||
RankedTensorType::get(indicesShape, indexElemTy);
|
||||
for (auto indexTensor : indexTensors) {
|
||||
|
@ -290,7 +289,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|||
AtenEmbeddingOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto weight = adaptor.getWeight();
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
if (!weightTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
|
||||
|
@ -332,17 +331,17 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
|
|||
Value indices = adaptor.getIndices();
|
||||
Value offsets = adaptor.getOffsets();
|
||||
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
if (weightTy && weightTy.hasStaticShape() && weightTy.getRank() != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "weight must be rank 2 tensor with static shapes");
|
||||
|
||||
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
||||
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
||||
if (indicesTy && indicesTy.hasStaticShape() && indicesTy.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "indices must be a vector with static shapes");
|
||||
|
||||
auto offsetsTy = offsets.getType().cast<RankedTensorType>();
|
||||
auto offsetsTy = cast<RankedTensorType>(offsets.getType());
|
||||
if (offsetsTy && offsetsTy.getRank() != 1 && offsetsTy.hasStaticShape() &&
|
||||
offsetsTy.getShape()[0] == 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -485,7 +484,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
AtenIndexSelectOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
int64_t dim;
|
||||
|
@ -514,8 +513,8 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
Location loc = op->getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
Value index = adaptor.getIndex();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto indexType = index.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto indexType = cast<RankedTensorType>(index.getType());
|
||||
auto indexElemType = indexType.getElementType();
|
||||
|
||||
if (indexType.getRank() != inputType.getRank()) {
|
||||
|
@ -623,7 +622,7 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
Value src = adaptor.getSrc();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
auto srcType = cast<RankedTensorType>(src.getType());
|
||||
int64_t srcRank = srcType.getRank();
|
||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||
auto abstractSrcType = RankedTensorType::get(
|
||||
|
@ -651,9 +650,9 @@ public:
|
|||
Value input = adaptor.getSelf();
|
||||
Value index = adaptor.getIndex();
|
||||
Value src = adaptor.getSrc();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto indexType = index.getType().cast<RankedTensorType>();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto indexType = cast<RankedTensorType>(index.getType());
|
||||
auto srcType = cast<RankedTensorType>(src.getType());
|
||||
auto indexElemType = indexType.getElementType();
|
||||
|
||||
if (indexType.getRank() != inputType.getRank() ||
|
||||
|
@ -789,9 +788,9 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op->getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTensorType = input.getType().cast<RankedTensorType>();
|
||||
auto inputTensorType = cast<RankedTensorType>(input.getType());
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto outShape = outType.getShape();
|
||||
Value indexList = op.getIndices();
|
||||
SmallVector<Value> indicesTorchType;
|
||||
|
@ -857,10 +856,10 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
Value input = adaptor.getSelf();
|
||||
Value values = adaptor.getValues();
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
int64_t inputRank = inputType.getRank();
|
||||
auto valuesType = values.getType().cast<RankedTensorType>();
|
||||
auto valuesType = cast<RankedTensorType>(values.getType());
|
||||
auto valuesShape = valuesType.getShape();
|
||||
bool accumulate;
|
||||
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace {
|
|||
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||
ArrayRef<int64_t> shape, ArrayRef<Value> dimSizes,
|
||||
ArrayRef<int64_t> broadcastDims) {
|
||||
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
|
||||
auto tensorTy = dyn_cast<RankedTensorType>(tensor.getType());
|
||||
auto loc = op->getLoc();
|
||||
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
|
||||
|
@ -48,7 +48,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
|||
|
||||
Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
||||
ArrayRef<int64_t> inpTransDims) {
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
auto rank = inputTy.getRank();
|
||||
auto transDims = hlo::toPositiveDims(inpTransDims, rank);
|
||||
auto inpShape = inputTy.getShape();
|
||||
|
@ -70,8 +70,8 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
|
|||
int64_t lhsResultDim, int64_t rhsResultDim,
|
||||
int64_t lhsContractingDim,
|
||||
int64_t rhsContractingDim) {
|
||||
auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
|
||||
auto rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
auto oldLhsShape = lhsTy.getShape();
|
||||
auto oldRhsShape = rhsTy.getShape();
|
||||
|
@ -129,8 +129,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
size_t dimSizeIndexBits) {
|
||||
Value lhs = inpLhs;
|
||||
Value rhs = inpRhs;
|
||||
auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhsRankTy = inpRhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto lhsRankTy = dyn_cast<RankedTensorType>(inpLhs.getType());
|
||||
auto rhsRankTy = dyn_cast<RankedTensorType>(inpRhs.getType());
|
||||
|
||||
auto lhsRank = lhsRankTy.getRank();
|
||||
auto rhsRank = rhsRankTy.getRank();
|
||||
|
@ -177,8 +177,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
return;
|
||||
}
|
||||
|
||||
lhsShape = lhs.getType().cast<RankedTensorType>().getShape();
|
||||
rhsShape = rhs.getType().cast<RankedTensorType>().getShape();
|
||||
lhsShape = cast<RankedTensorType>(lhs.getType()).getShape();
|
||||
rhsShape = cast<RankedTensorType>(rhs.getType()).getShape();
|
||||
|
||||
// check shape compatibility, check if we should broadcast
|
||||
// first, we should got a new batch shape. Check from (0, nBatchDims)
|
||||
|
@ -266,8 +266,8 @@ public:
|
|||
LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Value &lhs,
|
||||
Value &rhs, Value &output) const {
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
auto lhsRank = lhsTy.getRank();
|
||||
auto rhsRank = rhsTy.getRank();
|
||||
|
@ -370,10 +370,10 @@ public:
|
|||
ConversionPatternRewriter &rewriter,
|
||||
Value &lhs, Value &rhs) const override {
|
||||
lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
|
||||
rhs = adaptor.getOther();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
|
@ -393,10 +393,10 @@ public:
|
|||
ConversionPatternRewriter &rewriter,
|
||||
Value &lhs, Value &rhs) const override {
|
||||
lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
|
||||
rhs = adaptor.getMat2();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
|
@ -429,10 +429,10 @@ public:
|
|||
ConversionPatternRewriter &rewriter,
|
||||
Value &lhs, Value &rhs) const override {
|
||||
lhs = adaptor.getInput();
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
|
||||
rhs = adaptor.getWeight();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
|
@ -464,16 +464,15 @@ public:
|
|||
auto biasTy = bias.getType();
|
||||
|
||||
// StableHLO does not mandate that elementwise op tensors need to be ranked.
|
||||
if (!biasTy.template isa<Torch::NoneType>() &&
|
||||
!biasTy.template isa<RankedTensorType>())
|
||||
if (!isa<Torch::NoneType>(biasTy) && !isa<RankedTensorType>(biasTy))
|
||||
return op.emitError("only ranked tensor types are supported in StableHLO "
|
||||
"matmul for bias tensor");
|
||||
|
||||
// weight.T
|
||||
rhs = getPermutedTensor(rewriter, op, rhs, {1, 0});
|
||||
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
|
||||
rhsTy.getRank() - lhsTy.getRank());
|
||||
|
||||
|
@ -503,7 +502,7 @@ public:
|
|||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
||||
|
||||
Value matmulPlusBias = matmulOutput;
|
||||
if (!biasTy.template isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(biasTy)) {
|
||||
// Bias addition broadcasts to the matmul output shape.
|
||||
matmulPlusBias = rewriter
|
||||
.create<chlo::BroadcastAddOp>(
|
||||
|
@ -525,7 +524,7 @@ public:
|
|||
|
||||
Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,
|
||||
Value weight, int64_t groups) const {
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
auto weightElemTy = weightTy.getElementType();
|
||||
auto rank = weightTy.getRank();
|
||||
const auto &options = getOptions();
|
||||
|
@ -588,8 +587,8 @@ public:
|
|||
ArrayRef<int64_t> dilation,
|
||||
ArrayRef<int64_t> outputPadding,
|
||||
int64_t groups) const {
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
auto weightShape = weightTy.getShape();
|
||||
|
||||
auto nDims = inputTy.getRank();
|
||||
|
@ -727,11 +726,11 @@ public:
|
|||
Value weight = adaptor.getWeight();
|
||||
|
||||
// The input shape is [N, C, H, W]
|
||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
// The weight shape is [OC, (IC//G), KH, KW]
|
||||
// If transposed is set to true,
|
||||
// the weight shape changes to [IC, (OC//G), KH, KW]
|
||||
auto weightTy = weight.getType().template cast<RankedTensorType>();
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
|
@ -819,11 +818,11 @@ public:
|
|||
}
|
||||
|
||||
// Handle bias
|
||||
if (!bias.getType().cast<RankedTensorType>()) {
|
||||
if (!cast<RankedTensorType>(bias.getType())) {
|
||||
return op.emitError("bias provided but not a ranked tensor");
|
||||
}
|
||||
|
||||
auto biasTy = bias.getType().cast<RankedTensorType>();
|
||||
auto biasTy = cast<RankedTensorType>(bias.getType());
|
||||
if (!biasTy.getElementType().isIntOrFloat()) {
|
||||
return op.emitError("only floating-point or integer datatype "
|
||||
"legalization for bias supported");
|
||||
|
|
|
@ -81,12 +81,12 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|||
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
||||
auto inputRank = inputTy.getRank();
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
if (inputRank <= 2) {
|
||||
return op.emitError(
|
||||
|
@ -176,14 +176,14 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputShape = inputTy.getShape();
|
||||
auto inputRank = inputTy.getRank();
|
||||
auto outValTy =
|
||||
getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
|
||||
auto outIdxTy =
|
||||
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||
|
||||
if (inputRank <= 2) {
|
||||
return op.emitError(
|
||||
|
@ -366,7 +366,7 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value input = adaptor.getSelf();
|
||||
RankedTensorType inputTy = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
|
||||
Type inputElemTy = inputTy.getElementType();
|
||||
int64_t inputRank = inputTy.getRank();
|
||||
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
|
||||
|
@ -539,11 +539,11 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
|||
AtenCumsumOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
inputTy = input.getType().cast<RankedTensorType>();
|
||||
inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputRank = inputTy.getRank();
|
||||
auto inputShape = inputTy.getShape();
|
||||
|
|
|
@ -126,7 +126,7 @@ static std::optional<ValueRange>
|
|||
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
||||
ArrayRef<Value> inputShapeVec, int64_t dim,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
if (!inputTy) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -249,7 +249,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
AtenArgmaxOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
|
@ -321,7 +321,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
AtenMaxDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
|
@ -410,7 +410,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
AtenSumOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
|
@ -423,7 +423,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
auto dstElemTy = outTy.getElementType();
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
if (!inputElemTy.isIntOrFloat()) {
|
||||
|
@ -626,7 +626,7 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
|
|||
AtenProdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
|
@ -639,7 +639,7 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
|
|||
auto dstElemTy = outTy.getElementType();
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
if (!inputElemTy.isIntOrFloat()) {
|
||||
|
@ -699,7 +699,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
AtenMaxOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
|
@ -762,7 +762,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
|
|||
AtenMinOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
|
@ -825,7 +825,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
AtenSumDimIntListOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
|
@ -838,7 +838,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
auto dstElemTy = outTy.getElementType();
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
if (!inputElemTy.isIntOrFloat()) {
|
||||
|
@ -958,7 +958,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
const TorchToStablehloOptions &options = getOptions();
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputType) {
|
||||
return op.emitError(
|
||||
"only ranked tensor input supported in AtenFrobeniusNormDimOp");
|
||||
|
@ -1070,7 +1070,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
|
|||
const TorchToStablehloOptions &options = getOptions();
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputType) {
|
||||
return op.emitError(
|
||||
"only ranked tensor input supported in AtenLinalgVectorNormOp");
|
||||
|
@ -1078,7 +1078,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
|
|||
int64_t inputRank = inputType.getRank();
|
||||
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto outElemType = outType.getElementType();
|
||||
if (!isa<mlir::FloatType>(outElemType)) {
|
||||
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");
|
||||
|
|
|
@ -144,7 +144,7 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
|||
|
||||
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
|
||||
TensorType outType) {
|
||||
TensorType in_type = input.getType().cast<TensorType>();
|
||||
TensorType in_type = cast<TensorType>(input.getType());
|
||||
|
||||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promotedType =
|
||||
|
@ -162,7 +162,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
|||
// dimension, the dimension sizes must either be equal, one of them is 1, or
|
||||
// one of them does not exist.
|
||||
Operation *op = input.getDefiningOp();
|
||||
TensorType in_type = input.getType().dyn_cast<TensorType>();
|
||||
TensorType in_type = dyn_cast<TensorType>(input.getType());
|
||||
|
||||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promoted_type =
|
||||
|
@ -217,7 +217,7 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
|||
Operation *op, Value value,
|
||||
ArrayRef<int64_t> inpDims,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
|
||||
auto valueTy = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!valueTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "getDimSizesOfTensor(): the input is not a ranked tensor");
|
||||
|
@ -240,7 +240,7 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
|||
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Value value,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
|
||||
auto valueTy = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!valueTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "getDimSizesOfTensor(): the input is not a ranked tensor");
|
||||
|
@ -279,7 +279,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
op, "unsqueeze dimensions must be specified in order");
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
|
||||
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
||||
auto oldShape = rankTy.getShape();
|
||||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||
auto one = rewriter.create<arith::ConstantOp>(
|
||||
|
|
|
@ -72,7 +72,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
|
|||
SmallVector<Value, 4> endIndices;
|
||||
SmallVector<Value, 4> strides;
|
||||
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
size_t rank = inputTy.getRank();
|
||||
startIndices.reserve(rank);
|
||||
endIndices.reserve(rank);
|
||||
|
@ -116,7 +116,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
|||
std::optional<Value> stepOpt, int64_t dim,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto loc = op->getLoc();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
auto rank = inputTy.getRank();
|
||||
|
||||
dim = (dim + rank) % rank;
|
||||
|
@ -168,8 +168,7 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rankType =
|
||||
adaptor.getSelf().getType().template dyn_cast<RankedTensorType>();
|
||||
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
if (!rankType)
|
||||
return op.emitError("Only ranked tensor types are currently supported");
|
||||
|
||||
|
@ -233,11 +232,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|||
AtenSliceTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -275,7 +274,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
|||
AtenSqueezeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
|
||||
|
@ -318,7 +317,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
|||
AtenSqueezeDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
|
||||
|
@ -369,7 +368,7 @@ template <>
|
|||
LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
||||
AtenUnsqueezeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType) {
|
||||
return op.emitError("only tensor types are currently supported");
|
||||
}
|
||||
|
@ -378,7 +377,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("dim must be a Scalar constant");
|
||||
int64_t inputRank =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
dim = toPositiveDim(dim, inputRank + 1);
|
||||
if (!isValidDim(dim, inputRank + 1))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
@ -397,7 +396,7 @@ template <>
|
|||
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
||||
PrimsCollapseOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getA().getType());
|
||||
if (!selfType) {
|
||||
return op.emitError("only tensor types are currently supported");
|
||||
}
|
||||
|
|
|
@ -89,8 +89,8 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
|
|||
Value indices, Value src,
|
||||
int64_t dim) {
|
||||
// Get information on types for inputs
|
||||
RankedTensorType indexType = indices.getType().cast<RankedTensorType>();
|
||||
RankedTensorType srcSelf = src.getType().cast<RankedTensorType>();
|
||||
RankedTensorType indexType = cast<RankedTensorType>(indices.getType());
|
||||
RankedTensorType srcSelf = cast<RankedTensorType>(src.getType());
|
||||
|
||||
// Store location for insertions
|
||||
Location loc = src.getLoc();
|
||||
|
@ -219,7 +219,7 @@ static Value createTMTensorScatterOp(
|
|||
llvm::ArrayRef<int64_t> dimensionsMap, bool uniqueIndices,
|
||||
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
|
||||
auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap);
|
||||
auto originalTensorType = original.getType().cast<RankedTensorType>();
|
||||
auto originalTensorType = cast<RankedTensorType>(original.getType());
|
||||
Type originalElementType = originalTensorType.getElementType();
|
||||
auto scatterOp = b.create<TMTensor::ScatterOp>(
|
||||
loc, originalTensorType, ValueRange{updates, indices},
|
||||
|
@ -241,8 +241,8 @@ static Value createTMTensorScanOp(
|
|||
OpBuilder &b, Location loc, Value input, Value output, Value accumulator,
|
||||
int64_t dim, bool inclusive,
|
||||
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto accType = accumulator.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto accType = cast<RankedTensorType>(accumulator.getType());
|
||||
Type elementType = inputType.getElementType();
|
||||
auto scanOp = b.create<TMTensor::ScanOp>(
|
||||
loc, TypeRange{inputType, accType}, input,
|
||||
|
@ -287,7 +287,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
|||
|
||||
// Step 3. Create comparison op which will be used as the sorting predicate.
|
||||
Value compareOp;
|
||||
if (auto intType = elementTypes[0].dyn_cast<mlir::IntegerType>()) {
|
||||
if (auto intType = dyn_cast<mlir::IntegerType>(elementTypes[0])) {
|
||||
// Case for using arith::CmpIOp.
|
||||
arith::CmpIPredicate ge = arith::CmpIPredicate::sge;
|
||||
arith::CmpIPredicate le = arith::CmpIPredicate::sle;
|
||||
|
@ -329,9 +329,9 @@ public:
|
|||
Value index = adaptor.getIndex();
|
||||
Value src = adaptor.getSrc();
|
||||
|
||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
||||
RankedTensorType indexType = index.getType().cast<RankedTensorType>();
|
||||
RankedTensorType srcType = src.getType().cast<RankedTensorType>();
|
||||
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||
RankedTensorType indexType = cast<RankedTensorType>(index.getType());
|
||||
RankedTensorType srcType = cast<RankedTensorType>(src.getType());
|
||||
if (selfType.getRank() != indexType.getRank() ||
|
||||
indexType.getRank() != srcType.getRank())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -385,7 +385,7 @@ public:
|
|||
// TODO: Add a check to verify that the input tensor elements are all
|
||||
// non-negative.
|
||||
// Check whether the input is a 1-d tensor of integer type or not.
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType.getRank() != 1 ||
|
||||
!inputType.getElementType().isa<mlir::IntegerType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -394,7 +394,7 @@ public:
|
|||
|
||||
// Check whether the input tensor element type is i64 or not.
|
||||
IntegerType inputIntegerType =
|
||||
inputType.getElementType().cast<IntegerType>();
|
||||
cast<IntegerType>(inputType.getElementType());
|
||||
if (inputIntegerType.getWidth() != 64)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
|
@ -409,7 +409,7 @@ public:
|
|||
SmallVector<int64_t> maxTensorSizes;
|
||||
ValueTensorType maxTensorType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef(maxTensorSizes),
|
||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
||||
cast<ValueTensorType>(torchTypeInput.getType()).getDtype());
|
||||
Value maxTensor =
|
||||
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
|
||||
maxTensor = typeConverter->materializeTargetConversion(
|
||||
|
@ -432,7 +432,7 @@ public:
|
|||
makeShapeTorchCompatible(inputType.getShape())[0], 1};
|
||||
ValueTensorType expandInputType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef(expandedInputSizes),
|
||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
||||
cast<ValueTensorType>(torchTypeInput.getType()).getDtype());
|
||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||
|
@ -571,7 +571,7 @@ Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
|
|||
}
|
||||
|
||||
BaseTensorType unsqueezedTensorType =
|
||||
indices[0].getType().cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(indices[0].getType());
|
||||
Value indicesTorchList = b.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(unsqueezedTensorType), indices);
|
||||
llvm::SmallVector<int64_t, 2> concatShape{
|
||||
|
@ -691,7 +691,7 @@ public:
|
|||
auto inputType = cast<ValueTensorType>(input.getType());
|
||||
auto valuesType = cast<ValueTensorType>(values.getType());
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
auto valuesTensorType = op.getValues().getType().cast<BaseTensorType>();
|
||||
auto valuesTensorType = cast<BaseTensorType>(op.getValues().getType());
|
||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
|
@ -902,9 +902,9 @@ public:
|
|||
Value gradOutput = adaptor.getGradOutput();
|
||||
Value input = adaptor.getSelf();
|
||||
RankedTensorType gradOutputType =
|
||||
gradOutput.getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(gradOutput.getType());
|
||||
Type gradOutputElemType = gradOutputType.getElementType();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
Type inputElemType = inputType.getElementType();
|
||||
int64_t tensorOperandRank = inputType.getRank();
|
||||
|
||||
|
@ -914,7 +914,7 @@ public:
|
|||
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||
indices = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||
RankedTensorType indicesType = indices.getType().cast<RankedTensorType>();
|
||||
RankedTensorType indicesType = cast<RankedTensorType>(indices.getType());
|
||||
Type indicesElemType = indicesType.getElementType();
|
||||
|
||||
// The element type of the `input` and `grad_output` should be same.
|
||||
|
@ -1100,11 +1100,11 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
|
||||
RankedTensorType selfType =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
RankedTensorType indexType =
|
||||
adaptor.getIndex().getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(adaptor.getIndex().getType());
|
||||
RankedTensorType srcType =
|
||||
adaptor.getSrc().getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(adaptor.getSrc().getType());
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
|
||||
|
@ -1324,7 +1324,7 @@ public:
|
|||
|
||||
// Step 1. Fetch Input to sort.
|
||||
Value inputTensor = adaptor.getSelf();
|
||||
auto inputType = inputTensor.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(inputTensor.getType());
|
||||
unsigned inputRank = inputType.getRank();
|
||||
|
||||
// Step 2. Fetch dimension to perform sort in.
|
||||
|
@ -1414,7 +1414,7 @@ public:
|
|||
.cast<RankedTensorType>();
|
||||
Type elementType = resultType.getElementType();
|
||||
Type inputElementType =
|
||||
input.getType().cast<RankedTensorType>().getElementType();
|
||||
cast<RankedTensorType>(input.getType()).getElementType();
|
||||
|
||||
// Converting the input element type to the result's element type.
|
||||
// The only possible mismatch would be when the input element type is an
|
||||
|
@ -1486,7 +1486,7 @@ public:
|
|||
Value isCausal = op.getIsCausal();
|
||||
Value scale = op.getScale();
|
||||
Type elementType =
|
||||
adaptor.getQuery().getType().cast<ShapedType>().getElementType();
|
||||
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
|
||||
|
||||
// Verify inputs (only support defaults)
|
||||
if (!mask.getType().isa<Torch::NoneType>())
|
||||
|
@ -1557,10 +1557,9 @@ public:
|
|||
key = collapseBatch(key);
|
||||
value = collapseBatch(value);
|
||||
|
||||
SmallVector<int64_t> outSizes(
|
||||
query.getType().cast<ShapedType>().getShape());
|
||||
SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
|
||||
SmallVector<int64_t> valueSizes(
|
||||
value.getType().cast<ShapedType>().getShape());
|
||||
cast<ShapedType>(value.getType()).getShape());
|
||||
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
|
||||
SmallVector<Value> outSizesDynamic(
|
||||
getTensorSizes(rewriter, op.getLoc(), query));
|
||||
|
|
|
@ -79,9 +79,9 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto operand = adaptor.getOperands()[0];
|
||||
auto operandTy = operand.getType().cast<RankedTensorType>();
|
||||
auto operandTy = cast<RankedTensorType>(operand.getType());
|
||||
auto resultTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
int64_t rank = operandTy.getRank();
|
||||
if (rank == 0) {
|
||||
|
|
|
@ -43,7 +43,7 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -93,9 +93,9 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
auto lhsTy = cast<TensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getOther();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
auto rhsTy = cast<TensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -235,15 +235,15 @@ public:
|
|||
// alpha : scalar: i32/i64/f32
|
||||
// output: tensor: tensor<i32/i64/f32>
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getOther();
|
||||
auto rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
auto rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||
|
||||
if (!lhsType)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
||||
if (auto lhsElemTy = lhsType.getElementType().dyn_cast<IntegerType>()) {
|
||||
if (auto lhsElemTy = dyn_cast<IntegerType>(lhsType.getElementType())) {
|
||||
if (lhsElemTy.getWidth() > 64)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Integers with widths greater than 64 are not supported");
|
||||
|
@ -284,7 +284,7 @@ public:
|
|||
op->getLoc(),
|
||||
RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs);
|
||||
// reinitialize right value type to tensor<i32/f32>
|
||||
rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||
}
|
||||
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||
|
||||
|
@ -337,9 +337,9 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
auto lhsTy = dyn_cast<TensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getOther();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
auto rhsTy = dyn_cast<TensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -409,7 +409,7 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
||||
|
||||
if (!lhsType)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -430,7 +430,7 @@ public:
|
|||
} else {
|
||||
Value rhsAsTensor;
|
||||
Value rhs = adaptor.getOther();
|
||||
auto rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
auto rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||
if (!rhsType) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
|
||||
rhsAsTensor, outElemTy, {}))) {
|
||||
|
@ -469,9 +469,9 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
auto lhsTy = dyn_cast<TensorType>(lhs.getType());
|
||||
Value rhs = adaptor.getOther();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
auto rhsTy = dyn_cast<TensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -497,7 +497,7 @@ public:
|
|||
|
||||
// auto result;
|
||||
Value result;
|
||||
if (outType.getElementType().template isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(outType.getElementType())) {
|
||||
// The input to the reciprocal is an integer sometimes, and we may need to
|
||||
// promote it to a floating point. Per TOSA specification, the input types
|
||||
// can only be floating point for tosa::ReciprocalOp.
|
||||
|
@ -538,7 +538,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
|||
AtenTanhOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
|
@ -555,7 +555,7 @@ LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
|
|||
AtenSigmoidOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
|
@ -572,7 +572,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
AtenReluOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
|
||||
// Maps to tosa.clamp which has both int and fp limits.
|
||||
int64_t clampMin = 0;
|
||||
|
@ -602,7 +602,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization currently supported");
|
||||
|
@ -660,7 +660,7 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -713,7 +713,7 @@ class ConvertAtenMultipleDimsReductionOp
|
|||
"non-const dim parameter unsupported");
|
||||
int64_t N = reduceDims.size();
|
||||
int64_t inputRank =
|
||||
adaptor.getSelf().getType().template cast<RankedTensorType>().getRank();
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
|
||||
if (!isValidDim(reduceDims[i], inputRank))
|
||||
|
@ -751,7 +751,7 @@ class ConvertAtenOneDimReductionOp
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dim parameter unsupported");
|
||||
int64_t inputRank =
|
||||
adaptor.getSelf().getType().template cast<RankedTensorType>().getRank();
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
reduceDim = toPositiveDim(reduceDim, inputRank);
|
||||
if (!isValidDim(reduceDim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
@ -782,7 +782,7 @@ public:
|
|||
ElementsAttr &reduceDimsAttr,
|
||||
bool &keepDims) const override {
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
|
||||
// Select all dims to reduce
|
||||
SmallVector<int64_t, 4> reduceDims;
|
||||
|
@ -804,7 +804,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -835,7 +835,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
// Create a single instance of tosa.argmax.
|
||||
// Multiple dims require chained construct.
|
||||
auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
|
||||
SmallVector<int64_t> outputShapeArr = {};
|
||||
int32_t i = 0;
|
||||
|
@ -865,7 +865,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
// Convert the final index to i64 for backend finalization, However, i64
|
||||
// is not a defined type for tosa.cast, so using arith.extsi instead.
|
||||
auto castToInt64 = [&](Value result) -> LogicalResult {
|
||||
auto resTy = result.getType().cast<ShapedType>();
|
||||
auto resTy = cast<ShapedType>(result.getType());
|
||||
if (!resTy)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Argmax: Result is not a shaped type");
|
||||
|
@ -915,7 +915,7 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1010,7 +1010,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1021,7 +1021,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
op, "Only floating-point datatype legalization supported");
|
||||
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
Value expTensor;
|
||||
Value expScalar = op.getExponent();
|
||||
|
@ -1063,8 +1063,8 @@ public:
|
|||
ConversionPatternRewriter &rewriter, Value &lhs,
|
||||
Value &rhs, Value &output) const {
|
||||
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
auto lhsRank = lhsTy.getRank();
|
||||
auto rhsRank = rhsTy.getRank();
|
||||
|
@ -1097,7 +1097,7 @@ public:
|
|||
// construct the input and output reshaping logic.
|
||||
auto getRankBroadcastedShape = [&](Value tensor,
|
||||
bool isRHS) -> SmallVector<int64_t> {
|
||||
auto tensorTy = tensor.getType().cast<TensorType>();
|
||||
auto tensorTy = cast<TensorType>(tensor.getType());
|
||||
auto tensorShape = makeShapeTorchCompatible(tensorTy.getShape());
|
||||
auto tensorRank = tensorTy.getRank();
|
||||
|
||||
|
@ -1151,7 +1151,7 @@ public:
|
|||
// TOSA matmul is performed on two 3D inputs and generates a 3D output.
|
||||
// Lower ranked tensors are dim-1 reshaped up to 3D
|
||||
auto reshapeUpTo3DTensor = [&](Value tensor) -> Value {
|
||||
auto tensorTy = tensor.getType().cast<TensorType>();
|
||||
auto tensorTy = cast<TensorType>(tensor.getType());
|
||||
auto rank = tensorTy.getRank();
|
||||
|
||||
assert(rank <= 3 && "reshapeUpTo3D tensor must receive rank <= 3");
|
||||
|
@ -1440,9 +1440,9 @@ public:
|
|||
}
|
||||
|
||||
auto matmulLhsShape = makeShapeTorchCompatible(
|
||||
matmulLhs.getType().template cast<RankedTensorType>().getShape());
|
||||
cast<RankedTensorType>(matmulLhs.getType()).getShape());
|
||||
auto matmulRhsShape = makeShapeTorchCompatible(
|
||||
matmulRhs.getType().template cast<RankedTensorType>().getShape());
|
||||
cast<RankedTensorType>(matmulRhs.getType()).getShape());
|
||||
|
||||
// The reshape/transpose should ensure the tosa.matmul always has same
|
||||
// batch size for either matrix. If if shapes are dynamic, they'll be
|
||||
|
@ -1642,10 +1642,10 @@ public:
|
|||
ConversionPatternRewriter &rewriter,
|
||||
Value &lhs, Value &rhs) const override {
|
||||
lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
|
||||
rhs = adaptor.getOther();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1666,10 +1666,10 @@ public:
|
|||
Value &lhs, Value &rhs) const override {
|
||||
|
||||
lhs = adaptor.getSelf();
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
|
||||
rhs = adaptor.getMat2();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1703,10 +1703,10 @@ public:
|
|||
Value &lhs, Value &rhs) const override {
|
||||
|
||||
lhs = adaptor.getInput();
|
||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||
|
||||
rhs = adaptor.getWeight();
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1744,14 +1744,13 @@ public:
|
|||
auto biasTy = bias.getType();
|
||||
|
||||
// TOSA does not mandate that elementwise op tensors need to be ranked.
|
||||
if (!biasTy.template isa<Torch::NoneType>() &&
|
||||
!biasTy.template isa<TensorType>())
|
||||
if (!isa<Torch::NoneType>(biasTy) && !isa<TensorType>(biasTy))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types supported in GEMM to TOSA for bias tensor");
|
||||
|
||||
// RHS must have its last two dims transposed prior to matrix
|
||||
// multiplication.
|
||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||
auto rhsRank = rhsTy.getRank();
|
||||
auto rhsShape = makeShapeTorchCompatible(rhsTy.getShape());
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
|
@ -1789,7 +1788,7 @@ public:
|
|||
"Failed to perform matmul operation");
|
||||
|
||||
Value matmulPlusBias = matmulOutput;
|
||||
if (!biasTy.template isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(biasTy)) {
|
||||
// Bias addition broadcasts to the matmul output shape.
|
||||
matmulPlusBias =
|
||||
rewriter
|
||||
|
@ -1818,7 +1817,7 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
|||
auto otherScalar = op.getOther();
|
||||
auto alphaScalar = op.getAlpha();
|
||||
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Rsub");
|
||||
|
@ -1867,8 +1866,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
auto input = adaptor.getInput();
|
||||
auto weight = adaptor.getWeight();
|
||||
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
auto outputTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
|
@ -1893,7 +1892,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
// Bias is optional. TOSA mandates a zero tensor here, so construct one if
|
||||
// required.
|
||||
auto bias = adaptor.getBias();
|
||||
if (adaptor.getBias().getType().template isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(adaptor.getBias().getType())) {
|
||||
// TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
|
||||
// accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
|
||||
// define a 48-bit int.
|
||||
|
@ -1909,7 +1908,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
.value();
|
||||
}
|
||||
} else {
|
||||
if (!bias.getType().cast<RankedTensorType>())
|
||||
if (!cast<RankedTensorType>(bias.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Bias provided but not a ranked tensor");
|
||||
}
|
||||
|
@ -2115,7 +2114,7 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(
|
|||
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Reshape");
|
||||
|
@ -2199,7 +2198,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a ranked tensor output
|
||||
if (!adaptor.getInput().getType().dyn_cast<RankedTensorType>())
|
||||
if (!dyn_cast<RankedTensorType>(adaptor.getInput().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types are supported");
|
||||
|
||||
|
@ -2211,8 +2210,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
if (op.getMomentum().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");
|
||||
|
||||
auto meanType = adaptor.getRunningMean().getType().dyn_cast<TensorType>();
|
||||
auto varianceType = adaptor.getRunningVar().getType().dyn_cast<TensorType>();
|
||||
auto meanType = dyn_cast<TensorType>(adaptor.getRunningMean().getType());
|
||||
auto varianceType = dyn_cast<TensorType>(adaptor.getRunningVar().getType());
|
||||
if (!varianceType || !meanType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types are supported");
|
||||
|
@ -2225,7 +2224,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
const TypeConverter *converter, Type outType,
|
||||
const Value toBcast, Value &result) {
|
||||
RankedTensorType toBcastType =
|
||||
toBcast.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(toBcast.getType());
|
||||
if (toBcastType.getRank() > 1)
|
||||
return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");
|
||||
|
||||
|
@ -2298,11 +2297,11 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
// eventually being reshaped for broadcasting.
|
||||
|
||||
// Not a ranked tensor output
|
||||
if (!adaptor.getInput().getType().dyn_cast<RankedTensorType>())
|
||||
if (!dyn_cast<RankedTensorType>(adaptor.getInput().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types are supported");
|
||||
|
||||
auto inputType = adaptor.getInput().getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(adaptor.getInput().getType());
|
||||
if (inputType.getRank() > 4)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Only up to 4D tensors are supported");
|
||||
|
@ -2317,8 +2316,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
if (adaptor.getBias().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(op, "Unsupported None for bias");
|
||||
|
||||
auto weightType = adaptor.getWeight().getType().cast<RankedTensorType>();
|
||||
auto biasType = adaptor.getBias().getType().cast<RankedTensorType>();
|
||||
auto weightType = cast<RankedTensorType>(adaptor.getWeight().getType());
|
||||
auto biasType = cast<RankedTensorType>(adaptor.getBias().getType());
|
||||
int64_t inputRank = inputType.getRank();
|
||||
Type elemTy = inputType.getElementType();
|
||||
SmallVector<int64_t> inputTypeShape(
|
||||
|
@ -2461,7 +2460,7 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
|||
// element type. All tensors with element types other than integer can reuse
|
||||
// existing elements attribute.
|
||||
// TODO: what about unsigned integer?
|
||||
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
|
||||
if (elements.getElementType().isSignedInteger()) {
|
||||
Type builtinTensorElemTy = outputTy.getElementType();
|
||||
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
||||
|
@ -2483,7 +2482,7 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a ranked tensor type
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
||||
auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Only ranked tensor types supported");
|
||||
|
@ -2548,7 +2547,7 @@ LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a ranked tensor type
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
||||
auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
|
@ -2602,7 +2601,7 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a ranked tensor type
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
||||
auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
|
@ -2637,7 +2636,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -2665,7 +2664,7 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -2715,7 +2714,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -2763,7 +2762,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -2781,7 +2780,7 @@ LogicalResult ConvertAtenOp<AtenDropoutOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getInput().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getInput().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -2807,7 +2806,7 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -2869,7 +2868,7 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
|
|||
//
|
||||
// Erf = 1 - 1 / (1 + a1X + a2X + a3X + a4X)^4
|
||||
|
||||
auto outType = x.getType().cast<TensorType>();
|
||||
auto outType = cast<TensorType>(x.getType());
|
||||
auto loc = op->getLoc();
|
||||
auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
|
||||
|
@ -2949,7 +2948,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -2986,7 +2985,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -3043,7 +3042,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -3063,7 +3062,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
Value gradOutput = adaptor.getGradOutput();
|
||||
auto gradOutputType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto gradOutputType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
|
||||
Type gradOutputElemType = gradOutputType.getElementType();
|
||||
|
||||
|
@ -3119,14 +3118,14 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|||
Value weight = adaptor.getWeight();
|
||||
Value indices = adaptor.getIndices();
|
||||
RankedTensorType outType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
auto indicesType = indices.getType().dyn_cast<RankedTensorType>();
|
||||
auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
|
||||
if (!indicesType || !indicesType.getElementType().isa<IntegerType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Indices must be of integer tensor type");
|
||||
|
||||
auto weightType = weight.getType().cast<RankedTensorType>();
|
||||
auto weightType = cast<RankedTensorType>(weight.getType());
|
||||
if (weightType.getRank() != 2)
|
||||
return op.emitError("weight must be of rank 2");
|
||||
|
||||
|
@ -3216,7 +3215,7 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
|
|||
AtenTransposeIntOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
|
||||
|
@ -3258,12 +3257,12 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
AtenMaxDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
|
||||
auto indicesType =
|
||||
getTypeConverter()->convertType(op.getType(1)).dyn_cast<TensorType>();
|
||||
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||
if (!indicesType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
|
||||
|
@ -3334,7 +3333,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|||
AtenSliceTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
|
@ -3406,7 +3405,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
|
@ -3500,13 +3499,13 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
|
||||
// Not a tensor type.
|
||||
auto input = adaptor.getSelf();
|
||||
auto inputType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
if (!inputType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only RankedTensorType input are currently supported");
|
||||
|
||||
auto index = adaptor.getIndex();
|
||||
auto indexType = adaptor.getIndex().getType().dyn_cast<RankedTensorType>();
|
||||
auto indexType = dyn_cast<RankedTensorType>(adaptor.getIndex().getType());
|
||||
auto inputShape = inputType.getShape();
|
||||
int paramsRank = inputShape.size();
|
||||
|
||||
|
@ -3593,13 +3592,13 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
|
|||
|
||||
// Not a tensor type.
|
||||
auto input = adaptor.getSelf();
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
||||
auto fillValues = adaptor.getValues();
|
||||
auto valuesType = adaptor.getValues().getType().dyn_cast<TensorType>();
|
||||
auto valuesType = dyn_cast<TensorType>(adaptor.getValues().getType());
|
||||
if (!valuesType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
@ -3640,7 +3639,7 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Multiple None index is not support for now.");
|
||||
}
|
||||
auto indexNextType = indexNext.getType().dyn_cast<RankedTensorType>();
|
||||
auto indexNextType = dyn_cast<RankedTensorType>(indexNext.getType());
|
||||
auto indexNextShape = indexNextType.getShape();
|
||||
|
||||
int64_t size = 1;
|
||||
|
@ -3652,7 +3651,7 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
|
|||
.value();
|
||||
}
|
||||
|
||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
auto indexShape = indexType.getShape();
|
||||
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
|
||||
indexesRank.push_back(indexType.getRank());
|
||||
|
@ -3734,7 +3733,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
// [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]]
|
||||
auto input = adaptor.getSelf();
|
||||
auto inputTensorType =
|
||||
adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
// Check input is a tensor type.
|
||||
if (!inputTensorType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -3771,7 +3770,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||
auto index = indexTensors[i];
|
||||
|
||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
auto indexShape = indexType.getShape();
|
||||
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
|
||||
indexesRank.push_back(indexType.getRank());
|
||||
|
@ -3837,7 +3836,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
|
||||
// Support for multiple index
|
||||
auto index = indexTensors[0];
|
||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
auto indexShape = indexType.getShape();
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
|
@ -3879,7 +3878,7 @@ LogicalResult ConvertAtenOp<AtenAbsOp>::matchAndRewrite(
|
|||
AtenAbsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
@ -3896,11 +3895,11 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
auto condType = adaptor.getCondition().getType().dyn_cast<TensorType>();
|
||||
auto condType = dyn_cast<TensorType>(adaptor.getCondition().getType());
|
||||
if (!condType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types condition are currently supported");
|
||||
|
@ -3919,11 +3918,11 @@ LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
|
||||
auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
|
||||
if (!otherType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types condition are currently supported");
|
||||
|
@ -3955,8 +3954,8 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
|
|||
op, "unimplemented: equal_nan is expected to be false");
|
||||
|
||||
// check tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
|
||||
if (!selfType || !otherType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
@ -3998,7 +3997,7 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only tensor types input are currently supported");
|
||||
|
@ -4251,8 +4250,8 @@ LogicalResult ConvertAtenOp<AtenCopyOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto srcType = adaptor.getSrc().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
auto srcType = dyn_cast<TensorType>(adaptor.getSrc().getType());
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
|
@ -4297,7 +4296,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
|
@ -4355,14 +4354,14 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Remainder");
|
||||
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
|
@ -4438,7 +4437,7 @@ public:
|
|||
// Apply the transposeDims vector on input to generate a transposed form.
|
||||
Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter,
|
||||
Value input, ArrayRef<int32_t> transposeDims) const {
|
||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
|
||||
auto inputRank = inputTy.getRank();
|
||||
|
@ -4462,8 +4461,7 @@ public:
|
|||
Value transposePoolingInputToHwc(AtenOpT op,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value input) const {
|
||||
auto inputRank =
|
||||
input.getType().template cast<RankedTensorType>().getRank();
|
||||
auto inputRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||
|
||||
SmallVector<int32_t> nchwToNhwc4DTransposeDims({0, 2, 3, 1});
|
||||
SmallVector<int32_t> chwToHwc3DTransposeDims({1, 2, 0});
|
||||
|
@ -4476,7 +4474,7 @@ public:
|
|||
Value transposePoolingOutputToChw(AtenOpT op,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value input) const {
|
||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputRank = inputTy.getRank();
|
||||
|
||||
SmallVector<int32_t> nhwcToNchw4DTransposeDims({0, 3, 1, 2});
|
||||
|
@ -4547,7 +4545,7 @@ public:
|
|||
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
||||
Type &outputTy) const override {
|
||||
auto inputXchw = adaptor.getSelf();
|
||||
auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
|
||||
auto inputTy = cast<RankedTensorType>(inputXchw.getType());
|
||||
if (!inputTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Adaptive avgpool requires ranked tensor input");
|
||||
|
@ -4659,7 +4657,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
|
|||
DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
|
||||
DenseI64ArrayAttr &pad) {
|
||||
|
||||
RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
|
||||
RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
|
||||
if (!inputTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Pooling op requires ranked tensor input");
|
||||
|
@ -4797,7 +4795,7 @@ public:
|
|||
// FIXME: Handle layout, device and pin_memory. Assume dtype has been
|
||||
// processed to set output type correctly?
|
||||
// The layout arg should be either `none` or `0` i.e. strided.
|
||||
if (!op.getLayout().getType().template isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||
int64_t tensorLayout;
|
||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4808,7 +4806,7 @@ public:
|
|||
}
|
||||
|
||||
bool pinMemory;
|
||||
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
|
||||
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4892,19 +4890,19 @@ public:
|
|||
}
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().template dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType || !outType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Only tensor types with static shapes input are currently supported");
|
||||
|
||||
auto maskType = adaptor.getMask().getType().template dyn_cast<TensorType>();
|
||||
auto maskType = dyn_cast<TensorType>(adaptor.getMask().getType());
|
||||
if (!maskType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types mask are currently supported");
|
||||
|
||||
Value rhs = adaptor.getValue();
|
||||
auto rhsType = rhs.getType().template dyn_cast<TensorType>();
|
||||
auto rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||
Value rhsAsTensor;
|
||||
if (!rhsType) { // scalar
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(),
|
||||
|
@ -4913,11 +4911,11 @@ public:
|
|||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
} else { // tensor
|
||||
rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||
}
|
||||
|
||||
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||
auto rhsTensorType = rhsTensor.getType().template dyn_cast<TensorType>();
|
||||
auto rhsTensorType = dyn_cast<TensorType>(rhsTensor.getType());
|
||||
if (rhsTensorType.getElementType() != outElemTy)
|
||||
rhsTensor = rewriter.create<tosa::CastOp>(
|
||||
op.getLoc(),
|
||||
|
@ -4940,7 +4938,7 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
int64_t memoryFormat;
|
||||
if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
|
||||
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()) &&
|
||||
(!matchPattern(op.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)) ||
|
||||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
|
@ -4964,7 +4962,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
auto selfElemTy = selfTy.getElementType();
|
||||
int64_t rank = selfTy.getRank();
|
||||
|
||||
|
@ -5033,7 +5031,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
auto outType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
int64_t rank = outType.getRank();
|
||||
int64_t dim;
|
||||
|
||||
|
@ -5074,7 +5072,7 @@ LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
|
|||
|
||||
// Converts AtenSqrtOp into pow(x, 0.5)
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().dyn_cast<TensorType>();
|
||||
auto selfTy = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
|
|
@ -117,8 +117,8 @@ template <>
|
|||
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
|
||||
Operation *op, TensorType outType,
|
||||
Value lhs, Value rhs) {
|
||||
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
|
||||
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
|
||||
auto lhsElemTy = cast<TensorType>(lhs.getType()).getElementType();
|
||||
auto rhsElemTy = cast<TensorType>(rhs.getType()).getElementType();
|
||||
if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
|
||||
(void)rewriter.notifyMatchFailure(op,
|
||||
"tosa.div only supports integer type");
|
||||
|
@ -148,8 +148,8 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
|||
// [2,1] [[0, 3, 2],[0, 3, 1]]
|
||||
// ]] 1*4*2 ]] 1*4*2*3
|
||||
|
||||
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
||||
auto indexType = indexValue.getType().dyn_cast<RankedTensorType>();
|
||||
auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
|
||||
auto indexType = dyn_cast<RankedTensorType>(indexValue.getType());
|
||||
auto paramsShape = paramsType.getShape(); // [1 4 3]
|
||||
auto indexShape = indexType.getShape(); // [1 4 2]
|
||||
int paramsRank = paramsShape.size(); // 3
|
||||
|
@ -214,8 +214,8 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
|
|||
Type outType, Value paramsValue,
|
||||
Value indicesValue) {
|
||||
auto resultType = dyn_cast<ShapedType>(outType);
|
||||
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
||||
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
|
||||
auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
|
||||
auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());
|
||||
|
||||
if (!resultType || !paramsType || !indicesType)
|
||||
return std::nullopt;
|
||||
|
@ -420,9 +420,9 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
|
|||
Value paramsValue, Value indicesValue,
|
||||
Value fillValues) {
|
||||
auto resultType = dyn_cast<ShapedType>(outType);
|
||||
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
||||
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
|
||||
auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
|
||||
auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
|
||||
auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());
|
||||
auto fillValuesType = dyn_cast<RankedTensorType>(fillValues.getType());
|
||||
|
||||
if (!resultType || !paramsType || !indicesType)
|
||||
return std::nullopt;
|
||||
|
@ -572,7 +572,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
|
|||
tosaFillValuesTileOp.getResult(),
|
||||
rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape));
|
||||
fillValues = newTosaFillValuesReshapeOp.getResult();
|
||||
fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
|
||||
fillValuesType = dyn_cast<RankedTensorType>(fillValues.getType());
|
||||
}
|
||||
|
||||
// fillK: range of each index, total number of fillInput(could be scatter)
|
||||
|
@ -691,7 +691,7 @@ std::optional<Value> convertReduceOpCommon(
|
|||
Type reduce_element_type, bool is_quantized, double input_scale,
|
||||
int64_t input_zp, double output_scale, int64_t output_zp) {
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
@ -754,7 +754,7 @@ convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType output_type, Value input_value,
|
||||
ElementsAttr axes_elems, bool keep_dims) {
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
@ -769,7 +769,7 @@ convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType output_type, Value input_value,
|
||||
ElementsAttr axes_elems, bool keep_dims) {
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
@ -784,7 +784,7 @@ convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType output_type, Value input_value,
|
||||
ElementsAttr axes_elems, bool keep_dims) {
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
@ -799,7 +799,7 @@ convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType output_type, Value input_value,
|
||||
ElementsAttr axes_elems, bool keep_dims) {
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
@ -814,7 +814,7 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType output_type, Value input_value,
|
||||
ElementsAttr axes_elems, bool keep_dims) {
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
@ -840,7 +840,7 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType output_type, Value input_value,
|
||||
ElementsAttr axes_elems, bool keep_dims) {
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
@ -863,9 +863,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
|
|||
|
||||
if (input_is_qtype) {
|
||||
auto input_qtype =
|
||||
input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
|
||||
cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||
auto output_qtype =
|
||||
output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
|
||||
cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||
|
||||
int32_t input_shift = 20;
|
||||
|
||||
|
@ -895,7 +895,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
|
|||
// op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
|
||||
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
@ -940,9 +940,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
|
|||
|
||||
if (input_is_qtype) {
|
||||
auto input_qtype =
|
||||
input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
|
||||
cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||
auto output_qtype =
|
||||
output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
|
||||
cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||
|
||||
// Combine 'div_scale' as part of output rescale
|
||||
output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
|
||||
|
@ -976,7 +976,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType output_type, Value input_value,
|
||||
ElementsAttr axes_elems, bool keep_dims) {
|
||||
RankedTensorType input_type =
|
||||
input_value.getType().dyn_cast<RankedTensorType>();
|
||||
dyn_cast<RankedTensorType>(input_value.getType());
|
||||
if (!input_type)
|
||||
return std::nullopt;
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
|
|||
Value input_val, double input_scale,
|
||||
int64_t input_zp) {
|
||||
// Output is always int32 type
|
||||
auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
|
||||
auto input_type = dyn_cast<mlir::ShapedType>(input_val.getType());
|
||||
assert(input_type);
|
||||
auto output_type = input_type.clone(rewriter.getI32Type());
|
||||
|
||||
|
@ -58,9 +58,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
|
|||
Value conv_val, ShapedType input_type,
|
||||
ShapedType weight_type, ShapedType output_type) {
|
||||
auto input_qtype =
|
||||
input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
|
||||
auto output_qtype = output_type.getElementType()
|
||||
.dyn_cast<mlir::quant::UniformQuantizedType>();
|
||||
dyn_cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||
auto output_qtype =
|
||||
dyn_cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||
|
||||
double input_scale = input_qtype.getScale();
|
||||
|
||||
|
@ -71,8 +71,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
|
|||
int32_t scale_width = scale32 ? 32 : 16;
|
||||
|
||||
if (auto weight_per_tensor_qtype =
|
||||
weight_type.getElementType()
|
||||
.dyn_cast<mlir::quant::UniformQuantizedType>()) {
|
||||
dyn_cast<mlir::quant::UniformQuantizedType>(
|
||||
weight_type.getElementType())) {
|
||||
// Per-tensor quantization
|
||||
double weight_scale = weight_per_tensor_qtype.getScale();
|
||||
|
||||
|
@ -94,8 +94,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
|
|||
return rescale_op.getResult();
|
||||
|
||||
} else if (auto weight_per_channel_qtype =
|
||||
weight_type.getElementType()
|
||||
.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
|
||||
dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(
|
||||
weight_type.getElementType())) {
|
||||
// Per-channel quantization
|
||||
SmallVector<int32_t> multiplier_arr;
|
||||
SmallVector<int8_t> shift_arr;
|
||||
|
@ -311,7 +311,7 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
|
|||
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
||||
Value src, Type destType, Value &result) {
|
||||
|
||||
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
|
||||
Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
|
||||
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
|
||||
|
||||
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
|
||||
|
@ -319,7 +319,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
|||
op, "casting to result dtype is invalid or unsupported");
|
||||
|
||||
if (destElemTy.isInteger(1)) {
|
||||
auto srcType = src.getType().dyn_cast<TensorType>();
|
||||
auto srcType = dyn_cast<TensorType>(src.getType());
|
||||
SmallVector<int64_t> srcShape(srcType.getShape());
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : srcShape)
|
||||
|
@ -355,7 +355,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
|||
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
||||
Operation *op = input.getDefiningOp();
|
||||
TensorType inType = input.getType().cast<TensorType>();
|
||||
TensorType inType = cast<TensorType>(input.getType());
|
||||
|
||||
if (inType.getElementType() != outType.getElementType()) {
|
||||
TensorType promotedType =
|
||||
|
|
|
@ -52,7 +52,7 @@ LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
|
|||
// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
|
||||
Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
|
||||
Value inputRank) {
|
||||
assert(dim.getType().isa<IntegerType>() &&
|
||||
assert(isa<IntegerType>(dim.getType()) &&
|
||||
"dim arg of toPositiveDim must be integer type");
|
||||
Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
|
||||
Value cst0 =
|
||||
|
@ -132,7 +132,7 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
|||
Type elemTy) {
|
||||
Value initTensor =
|
||||
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
||||
RankedTensorType type = initTensor.getType().cast<RankedTensorType>();
|
||||
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
|
||||
Value c0 =
|
||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
|
||||
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||
|
@ -172,7 +172,7 @@ Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
|
|||
|
||||
SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
|
||||
Value tensor, int dim) {
|
||||
RankedTensorType type = tensor.getType().cast<RankedTensorType>();
|
||||
RankedTensorType type = cast<RankedTensorType>(tensor.getType());
|
||||
assert(dim < type.getRank() &&
|
||||
"The given dim must be smaller than tensor rank");
|
||||
(void)type;
|
||||
|
@ -183,7 +183,7 @@ SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor) {
|
||||
RankedTensorType type = tensor.getType().cast<RankedTensorType>();
|
||||
RankedTensorType type = cast<RankedTensorType>(tensor.getType());
|
||||
return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
|
||||
}
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v,
|
|||
|
||||
OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
|
||||
int64_t dim) {
|
||||
auto t = v.getType().cast<ShapedType>();
|
||||
auto t = cast<ShapedType>(v.getType());
|
||||
if (t.isDynamicDim(dim)) {
|
||||
return getDimValue(builder, loc, v, dim);
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
|
|||
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
|
||||
Value rhs, ValueRange rhsSizes, Value output,
|
||||
ValueRange outputSizes, bool transposed = false) {
|
||||
auto elementType = lhs.getType().cast<MemRefType>().getElementType();
|
||||
auto elementType = cast<MemRefType>(lhs.getType()).getElementType();
|
||||
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
|
||||
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto rank = outputSizes.size();
|
||||
|
@ -168,9 +168,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
|||
Value key = getKey();
|
||||
Value value = getValue();
|
||||
Value output = getOutput();
|
||||
auto queryType = query.getType().cast<MemRefType>();
|
||||
auto keyType = key.getType().cast<MemRefType>();
|
||||
auto valueType = value.getType().cast<MemRefType>();
|
||||
auto queryType = cast<MemRefType>(query.getType());
|
||||
auto keyType = cast<MemRefType>(key.getType());
|
||||
auto valueType = cast<MemRefType>(value.getType());
|
||||
auto queryRank = queryType.getRank();
|
||||
auto keyRank = keyType.getRank();
|
||||
auto valueRank = valueType.getRank();
|
||||
|
@ -330,12 +330,12 @@ LogicalResult ScanOp::verify() {
|
|||
if (getNumOutputs() != 2) {
|
||||
return emitOpError("expected two output operands");
|
||||
}
|
||||
if (!input().getType().isa<ShapedType>()) {
|
||||
if (!isa<ShapedType>(input().getType())) {
|
||||
return emitOpError("expected first input element type to be shaped");
|
||||
}
|
||||
auto accumulatorType = accumulator().getType().cast<ShapedType>();
|
||||
auto inputType = input().getType().cast<ShapedType>();
|
||||
auto outputType = output().getType().cast<ShapedType>();
|
||||
auto accumulatorType = cast<ShapedType>(accumulator().getType());
|
||||
auto inputType = cast<ShapedType>(input().getType());
|
||||
auto outputType = cast<ShapedType>(output().getType());
|
||||
ArrayRef<int64_t> inputShapes = inputType.getShape();
|
||||
ArrayRef<int64_t> outputShapes = outputType.getShape();
|
||||
if (accumulatorType.getElementType() != inputType.getElementType()) {
|
||||
|
@ -706,7 +706,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
|
|||
loadIndices.push_back(Value());
|
||||
|
||||
// Populate with empty values.
|
||||
auto originalTy = original().getType().cast<ShapedType>();
|
||||
auto originalTy = cast<ShapedType>(original().getType());
|
||||
starts.resize(originalTy.getRank(), Value());
|
||||
auto updateIvs = ivs.drop_front(1);
|
||||
|
||||
|
@ -797,7 +797,7 @@ LogicalResult SortOp::verify() {
|
|||
if (yieldOp.getNumOperands() != 1) {
|
||||
return op->emitOpError("should yield exactly one operand");
|
||||
}
|
||||
auto ty = yieldOp.getOperand(0).getType().dyn_cast<IntegerType>();
|
||||
auto ty = dyn_cast<IntegerType>(yieldOp.getOperand(0).getType());
|
||||
if (!ty || ty.getWidth() != 1) {
|
||||
return op->emitOpError("should yield i1 type");
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ using namespace ::mlir;
|
|||
using namespace ::mlir::torch::TMTensor;
|
||||
|
||||
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
auto memrefType = cast<MemRefType>(memref.getType());
|
||||
auto alloc = b.create<memref::AllocOp>(
|
||||
loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType());
|
||||
b.create<memref::CopyOp>(loc, memref, alloc);
|
||||
|
|
|
@ -80,7 +80,7 @@ struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
|
|||
return failure();
|
||||
}
|
||||
if (llvm::any_of(scalarLoopOp->getResults(),
|
||||
[&](Value v) { return v.getType().isa<ShapedType>(); })) {
|
||||
[&](Value v) { return isa<ShapedType>(v.getType()); })) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
scalarLoopOp, "lower to loops needs to have tensor semantics");
|
||||
}
|
||||
|
|
|
@ -122,14 +122,14 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
|||
auto func = dyn_cast<func::FuncOp>(op);
|
||||
if (!func)
|
||||
return op->emitError() << "'torch.type_bound' must be attached to a func";
|
||||
TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>();
|
||||
TypeAttr attr = dyn_cast<TypeAttr>(namedAttr.getValue());
|
||||
if (!attr)
|
||||
return op->emitError() << "'torch.type_bound' must be TypeAttr";
|
||||
auto type = attr.getValue().dyn_cast<BaseTensorType>();
|
||||
auto type = dyn_cast<BaseTensorType>(attr.getValue());
|
||||
if (!type)
|
||||
return op->emitError() << "'torch.type_bound' must be of "
|
||||
"!torch.tensor/!torch.vtensor type";
|
||||
if (!func.getFunctionType().getInput(argIndex).isa<BaseTensorType>())
|
||||
if (!isa<BaseTensorType>(func.getFunctionType().getInput(argIndex)))
|
||||
return op->emitError() << "'torch.type_bound' must be attached to an "
|
||||
"argument of !torch.tensor/!torch.vtensor type";
|
||||
return success();
|
||||
|
|
|
@ -75,7 +75,7 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
|
|||
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||
BaseTensorType newType,
|
||||
Value tensor) {
|
||||
auto originalType = tensor.getType().cast<BaseTensorType>();
|
||||
auto originalType = cast<BaseTensorType>(tensor.getType());
|
||||
// Adjust the static information in the type to match between the original and
|
||||
// new types.
|
||||
if (!originalType.hasSameSizesAndDtype(newType)) {
|
||||
|
@ -87,7 +87,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
|||
// up creating one op that converts between the value and non-value tensor
|
||||
// domains. If both the original and new types are both non-value tensors,
|
||||
// then we do the copy by going to a value tensor and back.
|
||||
if (tensor.getType().isa<NonValueTensorType>())
|
||||
if (isa<NonValueTensorType>(tensor.getType()))
|
||||
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
|
||||
if (isa<NonValueTensorType>(newType))
|
||||
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
|
||||
|
@ -96,7 +96,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
|||
}
|
||||
|
||||
bool mlir::torch::Torch::isListPotentiallyMutated(Value list) {
|
||||
assert(list.getType().isa<Torch::ListType>());
|
||||
assert(isa<Torch::ListType>(list.getType()));
|
||||
return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands);
|
||||
}
|
||||
|
||||
|
@ -148,8 +148,7 @@ static Value getScalarIntValue(Value input, Location loc,
|
|||
return nullptr;
|
||||
|
||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseIntElementsAttr>()
|
||||
auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
|
||||
.getSplatValue<int64_t>();
|
||||
return rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
|
@ -777,7 +776,7 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
|
|||
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
|
||||
if (getOperand(0).getType() != getResult().getType())
|
||||
return nullptr;
|
||||
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
|
||||
if (auto tensorType = dyn_cast<BaseTensorType>(getOperand(0).getType())) {
|
||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
||||
return getOperand(0);
|
||||
}
|
||||
|
@ -798,11 +797,11 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
|
|||
if (!matchPattern(getCopy(), m_TorchConstantBool(©Arg)) || copyArg)
|
||||
return nullptr;
|
||||
// The memory_format arg must be `none`.
|
||||
if (!getMemoryFormat().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
|
||||
return nullptr;
|
||||
|
||||
auto inputType = getSelf().getType().cast<BaseTensorType>();
|
||||
auto resType = getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(getSelf().getType());
|
||||
auto resType = cast<BaseTensorType>(getType());
|
||||
// If the types aren't equal, then we can't fold.
|
||||
if (inputType != resType)
|
||||
return nullptr;
|
||||
|
@ -821,7 +820,7 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
||||
// The pin_memory arg should be either constant `False` or `none`.
|
||||
if (!getPinMemory().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(getPinMemory().getType())) {
|
||||
bool pinMemory;
|
||||
if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
||||
return nullptr;
|
||||
|
@ -844,15 +843,15 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
|
||||
// The device arg must be `none`.
|
||||
if (!getDevice().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(getDevice().getType()))
|
||||
return nullptr;
|
||||
|
||||
// The memory_format arg must be `none`.
|
||||
if (!getMemoryFormat().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
|
||||
return nullptr;
|
||||
|
||||
auto inputType = getSelf().getType().cast<BaseTensorType>();
|
||||
auto resType = getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(getSelf().getType());
|
||||
auto resType = cast<BaseTensorType>(getType());
|
||||
// If the types aren't equal, then we can't fold.
|
||||
if (inputType != resType)
|
||||
return nullptr;
|
||||
|
@ -863,7 +862,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
|
||||
// The layout arg should be either `none` or `0` i.e. strided.
|
||||
if (!getLayout().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(getLayout().getType())) {
|
||||
int64_t tensorLayout;
|
||||
if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||
return nullptr;
|
||||
|
@ -882,7 +881,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
|
|||
// is false
|
||||
patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
|
||||
// The pin_memory arg should be either constant `False` or `none`.
|
||||
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getPinMemory().getType())) {
|
||||
bool pinMemory;
|
||||
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
||||
return failure();
|
||||
|
@ -891,7 +890,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
|
|||
}
|
||||
|
||||
// The layout arg should be either `none` or `0` i.e. strided.
|
||||
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||
int64_t tensorLayout;
|
||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||
return failure();
|
||||
|
@ -899,7 +898,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
|
|||
return failure();
|
||||
}
|
||||
|
||||
if (op.getDevice().getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(op.getDevice().getType())) {
|
||||
// The device arg is `none`. Rewrite to to.dtype.
|
||||
AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
|
||||
op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
|
||||
|
@ -985,10 +984,10 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
|
||||
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
auto inputType = dyn_cast<BaseTensorType>(getOperand(0).getType());
|
||||
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
auto resType = getType().dyn_cast<BaseTensorType>();
|
||||
auto resType = dyn_cast<BaseTensorType>(getType());
|
||||
if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
if (inputType != resType)
|
||||
|
@ -1011,7 +1010,7 @@ OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (auto tensorType = dyn_cast<BaseTensorType>(getOperand().getType())) {
|
||||
if (tensorType.hasSizes())
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
||||
tensorType.getSizes().size());
|
||||
|
@ -1117,7 +1116,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
}
|
||||
|
||||
if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
|
||||
if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(op->getOperand(2).getType())) {
|
||||
// None rounding mode
|
||||
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
|
||||
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
|
||||
|
@ -1879,9 +1878,9 @@ OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
||||
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||
if (resultType && resultType.hasDtype() &&
|
||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
||||
isa<mlir::IntegerType>(resultType.getDtype())) {
|
||||
return getSelf();
|
||||
}
|
||||
return {};
|
||||
|
@ -1892,9 +1891,9 @@ OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
||||
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||
if (resultType && resultType.hasDtype() &&
|
||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
||||
isa<mlir::IntegerType>(resultType.getDtype())) {
|
||||
return getSelf();
|
||||
}
|
||||
return {};
|
||||
|
@ -1905,9 +1904,9 @@ OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
||||
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||
if (resultType && resultType.hasDtype() &&
|
||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
||||
isa<mlir::IntegerType>(resultType.getDtype())) {
|
||||
return getSelf();
|
||||
}
|
||||
return {};
|
||||
|
@ -1918,7 +1917,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
||||
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||
if (resultType && resultType.hasDtype() &&
|
||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
||||
return getSelf();
|
||||
|
@ -1987,7 +1986,7 @@ void AtenDivScalarModeOp::getCanonicalizationPatterns(
|
|||
void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) {
|
||||
auto inputType = op.getSelf().getType().dyn_cast<BaseTensorType>();
|
||||
auto inputType = dyn_cast<BaseTensorType>(op.getSelf().getType());
|
||||
if (!inputType || !inputType.areAllSizesKnown()) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -2113,7 +2112,7 @@ traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
|
|||
if (!value || !value.getType().isa<BaseTensorType>())
|
||||
return failure();
|
||||
|
||||
auto tensorType = value.getType().cast<BaseTensorType>();
|
||||
auto tensorType = cast<BaseTensorType>(value.getType());
|
||||
if (foundType(tensorType, dim))
|
||||
return tensorType;
|
||||
|
||||
|
@ -2649,7 +2648,7 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
|
|||
.dyn_cast_or_null<ElementsAttr>();
|
||||
if (!attr)
|
||||
return failure();
|
||||
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
|
||||
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
|
||||
NonValueTensorType returnType =
|
||||
NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
|
||||
tensorType.getElementType());
|
||||
|
@ -2691,7 +2690,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
|
|||
.dyn_cast_or_null<ElementsAttr>();
|
||||
if (!attr)
|
||||
return failure();
|
||||
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
|
||||
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
|
||||
ValueTensorType returnType =
|
||||
ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
|
||||
tensorType.getElementType());
|
||||
|
@ -2751,8 +2750,8 @@ void TensorStaticInfoCastOp::getCanonicalizationPatterns(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult CopyToNonValueTensorOp::verify() {
|
||||
auto resultType = getResult().getType().cast<BaseTensorType>();
|
||||
auto operandType = getOperand().getType().cast<BaseTensorType>();
|
||||
auto resultType = cast<BaseTensorType>(getResult().getType());
|
||||
auto operandType = cast<BaseTensorType>(getOperand().getType());
|
||||
if (!resultType.hasSameSizesAndDtype(operandType))
|
||||
return emitError() << "operand and result must have same sizes and dtype";
|
||||
return success();
|
||||
|
@ -2762,7 +2761,7 @@ LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
|
|||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto resultType = operands[0].getType().cast<ValueTensorType>();
|
||||
auto resultType = cast<ValueTensorType>(operands[0].getType());
|
||||
inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
|
||||
return success();
|
||||
}
|
||||
|
@ -2778,8 +2777,8 @@ void CopyToNonValueTensorOp::getEffects(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult CopyToValueTensorOp::verify() {
|
||||
auto resultType = getResult().getType().cast<BaseTensorType>();
|
||||
auto operandType = getOperand().getType().cast<BaseTensorType>();
|
||||
auto resultType = cast<BaseTensorType>(getResult().getType());
|
||||
auto operandType = cast<BaseTensorType>(getOperand().getType());
|
||||
if (!resultType.hasSameSizesAndDtype(operandType))
|
||||
return emitError() << "operand and result must have same sizes and dtype";
|
||||
return success();
|
||||
|
@ -2789,7 +2788,7 @@ LogicalResult CopyToValueTensorOp::inferReturnTypes(
|
|||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto resultType = operands[0].getType().cast<NonValueTensorType>();
|
||||
auto resultType = cast<NonValueTensorType>(operands[0].getType());
|
||||
inferredReturnTypes.push_back(resultType.getWithValueSemantics());
|
||||
return success();
|
||||
}
|
||||
|
@ -3004,7 +3003,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
|
||||
auto operandType = getSelf().getType().dyn_cast<BaseTensorType>();
|
||||
auto operandType = dyn_cast<BaseTensorType>(getSelf().getType());
|
||||
if (!operandType)
|
||||
return nullptr;
|
||||
if (operandType.hasDtype()) {
|
||||
|
@ -3493,8 +3492,8 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
|
||||
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
|
||||
auto inType = dyn_cast<BaseTensorType>(getOperand(0).getType());
|
||||
auto outType = dyn_cast<BaseTensorType>(getResult().getType());
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
||||
!outType.hasDtype())
|
||||
return nullptr;
|
||||
|
@ -3534,8 +3533,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
|||
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
|
||||
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
|
||||
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
|
||||
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
||||
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
|
||||
auto outType = dyn_cast<ValueTensorType>(getResult().getType());
|
||||
|
||||
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
||||
start.getValue().getSExtValue() == 0 &&
|
||||
|
@ -3793,7 +3792,7 @@ OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
|
||||
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(getA().getType());
|
||||
if (tensorType.hasDtype()) {
|
||||
torch_upstream::ScalarType scalarType =
|
||||
Torch::getScalarTypeForType(tensorType.getDtype());
|
||||
|
@ -4568,7 +4567,7 @@ LogicalResult AtenNormScalarOp::verify() {
|
|||
// Per PyTorch docs, only float and complex types are valid for norm
|
||||
// operation.
|
||||
|
||||
auto inTensor = getSelf().getType().cast<BaseTensorType>();
|
||||
auto inTensor = cast<BaseTensorType>(getSelf().getType());
|
||||
|
||||
// If no dtype is specified, it will default to a float one.
|
||||
if (!inTensor.hasDtype()) {
|
||||
|
@ -4605,8 +4604,8 @@ LogicalResult AtenPermuteOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
auto outType = getResult().getType().cast<BaseTensorType>();
|
||||
auto inType = getSelf().getType().cast<BaseTensorType>();
|
||||
auto outType = cast<BaseTensorType>(getResult().getType());
|
||||
auto inType = cast<BaseTensorType>(getSelf().getType());
|
||||
|
||||
if (!outType.hasSizes() || !inType.hasSizes()) {
|
||||
return success();
|
||||
|
@ -4689,8 +4688,8 @@ LogicalResult AtenPermuteOp::verify() {
|
|||
|
||||
LogicalResult AtenLinalgCrossOp::verify() {
|
||||
|
||||
auto selfType = getSelf().getType().cast<BaseTensorType>();
|
||||
auto otherType = getOther().getType().cast<BaseTensorType>();
|
||||
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
||||
auto otherType = cast<BaseTensorType>(getOther().getType());
|
||||
|
||||
if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
|
||||
!otherType.hasSizes()) {
|
||||
|
@ -4857,7 +4856,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
|||
|
||||
// Check that initial values satisfy type bounds.
|
||||
for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) {
|
||||
auto symName = initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
|
||||
auto symName = cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
|
||||
auto initialValue = initialize.getOperand(i);
|
||||
auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue());
|
||||
if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) {
|
||||
|
|
|
@ -49,7 +49,7 @@ public:
|
|||
// The incoporation of the torch.type_bound arg attr is context-dependent.
|
||||
|
||||
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
|
||||
if (type.value().isa<NonValueTensorType>()) {
|
||||
if (isa<NonValueTensorType>(type.value())) {
|
||||
auto typeBoundAttr =
|
||||
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
|
||||
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
|
||||
|
@ -61,7 +61,7 @@ public:
|
|||
? typeBoundAttr.getValue()
|
||||
: type.value());
|
||||
continue;
|
||||
} else if (auto none = type.value().dyn_cast<Torch::NoneType>()) {
|
||||
} else if (auto none = dyn_cast<Torch::NoneType>(type.value())) {
|
||||
continue;
|
||||
}
|
||||
// TODO: add tuple type.
|
||||
|
@ -111,7 +111,7 @@ public:
|
|||
|
||||
SmallVector<Value> newOperands;
|
||||
for (auto operand : llvm::enumerate(adaptor.getOperands())) {
|
||||
if (operand.value().getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(operand.value().getType()))
|
||||
continue;
|
||||
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
|
||||
if (it != typeBoundMap.end()) {
|
||||
|
@ -167,9 +167,9 @@ public:
|
|||
for (auto operand : adaptor.getOperands()) {
|
||||
if (!operand)
|
||||
continue;
|
||||
if (operand.getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(operand.getType()))
|
||||
continue;
|
||||
if (auto tuple = operand.getType().dyn_cast<Torch::TupleType>()) {
|
||||
if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
|
||||
Location loc = op.getLoc();
|
||||
for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
|
||||
auto i = rewriter.create<ConstantIntOp>(
|
||||
|
@ -207,7 +207,7 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
|
|||
[](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs,
|
||||
Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<BaseTensorType>());
|
||||
assert(isa<BaseTensorType>(inputs[0].getType()));
|
||||
return copyTensorToType(builder, loc, type, inputs[0]);
|
||||
});
|
||||
patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
|
||||
|
|
|
@ -29,7 +29,7 @@ using namespace mlir::torch::Torch;
|
|||
|
||||
// Helper function to check whether the `dtype` is None or Float type.
|
||||
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
||||
if (dtype.getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(dtype.getType()))
|
||||
return true;
|
||||
int64_t dtypeInt;
|
||||
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
|
@ -87,7 +87,7 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
|
|||
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
||||
Value dtype = rewriter.create<ConstantNoneOp>(loc);
|
||||
Type resultType = computeReductionType(
|
||||
rewriter, op, input.getType().cast<BaseTensorType>(), dim, keepDim);
|
||||
rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim);
|
||||
if (!resultType)
|
||||
return nullptr;
|
||||
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
|
||||
|
@ -100,7 +100,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
|||
bool keepDim) {
|
||||
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
||||
BaseTensorType valueType =
|
||||
computeReductionType(rewriter, op, input.getType().cast<BaseTensorType>(),
|
||||
computeReductionType(rewriter, op, cast<BaseTensorType>(input.getType()),
|
||||
dim, keepDim)
|
||||
.cast<BaseTensorType>();
|
||||
if (!valueType)
|
||||
|
@ -296,7 +296,7 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
|||
int64_t contractingDimsLength,
|
||||
int64_t otherDimsLength,
|
||||
int64_t reduceDimsLength, bool isLhs) {
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
||||
reduceDimsLength;
|
||||
SmallVector<Value> inputShapeTensor;
|
||||
|
@ -415,7 +415,7 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
|
|||
SmallVector<char> &contractingDims,
|
||||
SmallVector<char> &otherDims,
|
||||
SmallVector<char> &reduceDims, bool isLhs) {
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
llvm::SmallDenseMap<char, int64_t> dimTokenMap;
|
||||
for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
|
||||
dimTokenMap[dimTokens[idx]] = idx;
|
||||
|
@ -451,8 +451,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
|||
Value &result,
|
||||
SmallVector<char> &resultTokens,
|
||||
SmallVector<char> &finalResultTokens) {
|
||||
auto lhsType = lhs.getType().cast<BaseTensorType>();
|
||||
auto rhsType = rhs.getType().cast<BaseTensorType>();
|
||||
auto lhsType = cast<BaseTensorType>(lhs.getType());
|
||||
auto rhsType = cast<BaseTensorType>(rhs.getType());
|
||||
|
||||
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
|
||||
: rhsType.getOptionalDtype();
|
||||
|
@ -562,7 +562,7 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter,
|
|||
Value input,
|
||||
SmallVector<char> &inputTokens,
|
||||
SmallVector<char> &outTokens) {
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
|
||||
llvm::SmallDenseSet<char> outTokenSet(outTokens.begin(), outTokens.end());
|
||||
SmallVector<int64_t> sumDims;
|
||||
|
@ -643,7 +643,7 @@ public:
|
|||
op, "Expected a constant boolean value for keepDim");
|
||||
|
||||
Value input = op.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
|
||||
if (!inputTy || !inputTy.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Expected input type having sizes");
|
||||
|
@ -677,7 +677,7 @@ public:
|
|||
MLIRContext *context = op.getContext();
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes() || !inputType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "should have shape and dtype");
|
||||
}
|
||||
|
@ -764,7 +764,7 @@ public:
|
|||
Value dim = op.getDim();
|
||||
Value self = op.getSelf();
|
||||
|
||||
auto resultTy = op.getType().cast<BaseTensorType>();
|
||||
auto resultTy = cast<BaseTensorType>(op.getType());
|
||||
if (!resultTy.hasSizes() || !resultTy.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected result type to have sizes and dtype");
|
||||
|
@ -785,8 +785,8 @@ public:
|
|||
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
||||
Value slice = rewriter.create<AtenSliceTensorOp>(
|
||||
loc,
|
||||
computeReductionType(rewriter, op,
|
||||
self.getType().cast<BaseTensorType>(), dim,
|
||||
computeReductionType(rewriter, op, cast<BaseTensorType>(self.getType()),
|
||||
dim,
|
||||
/*keepDim=*/true),
|
||||
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
||||
|
||||
|
@ -988,7 +988,7 @@ public:
|
|||
Value self = op.getSelf();
|
||||
Value dim = op.getDim();
|
||||
|
||||
auto outputTy = op.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
auto outputTy = dyn_cast<Torch::ValueTensorType>(op.getType());
|
||||
if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected output type having sizes and dtype");
|
||||
|
@ -1069,7 +1069,7 @@ public:
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"unimplemented: m must be constant");
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
auto outType = op.getType().dyn_cast<BaseTensorType>();
|
||||
auto outType = dyn_cast<BaseTensorType>(op.getType());
|
||||
if (!outType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
@ -1111,13 +1111,13 @@ public:
|
|||
|
||||
// compare unsqueezed input with boundaries
|
||||
auto eqType = ValueTensorType::get(
|
||||
context, op.getType().cast<BaseTensorType>().getSizes(),
|
||||
context, cast<BaseTensorType>(op.getType()).getSizes(),
|
||||
IntegerType::get(context, 1));
|
||||
Value eqTensor =
|
||||
rewriter.create<AtenEqTensorOp>(loc, eqType, unsqzRangeN, rangeM);
|
||||
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::BoolType>()) {
|
||||
if (isa<Torch::BoolType>(dtype.getType())) {
|
||||
rewriter.replaceOp(op, eqTensor);
|
||||
return success();
|
||||
} else {
|
||||
|
@ -1210,7 +1210,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.getSelf();
|
||||
// TODO: Handle non value tensor type operands.
|
||||
if (!input.getType().isa<ValueTensorType>()) {
|
||||
if (!isa<ValueTensorType>(input.getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only value tensor type operands are supported");
|
||||
}
|
||||
|
@ -1248,7 +1248,7 @@ public:
|
|||
}
|
||||
|
||||
auto allTensorHasSizes = [](Value tensor) {
|
||||
auto type = tensor.getType().dyn_cast<BaseTensorType>();
|
||||
auto type = dyn_cast<BaseTensorType>(tensor.getType());
|
||||
if (!type || !type.hasSizes())
|
||||
return false;
|
||||
return true;
|
||||
|
@ -1267,7 +1267,7 @@ public:
|
|||
if (equation.find("...") != std::string::npos) {
|
||||
SmallVector<int64_t> inputRanks;
|
||||
for (Value tensor : inputTensors) {
|
||||
auto type = tensor.getType().cast<BaseTensorType>();
|
||||
auto type = cast<BaseTensorType>(tensor.getType());
|
||||
inputRanks.push_back(type.getSizes().size());
|
||||
}
|
||||
|
||||
|
@ -1332,10 +1332,10 @@ public:
|
|||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
BaseTensorType inputType = self.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(self.getType());
|
||||
|
||||
Value output = op.getResult();
|
||||
BaseTensorType outputType = output.getType().cast<BaseTensorType>();
|
||||
BaseTensorType outputType = cast<BaseTensorType>(output.getType());
|
||||
|
||||
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
||||
int64_t diagonalSize = std::min(inputShape[0], inputShape[1]);
|
||||
|
@ -1399,7 +1399,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
|
||||
if (!resultTensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected result type to have a dtype");
|
||||
|
@ -1410,7 +1410,7 @@ public:
|
|||
"Only support floating-point type");
|
||||
|
||||
// If `dtype` arg is non-none then convert the input to `dtype`.
|
||||
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||
Location loc = op.getLoc();
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
|
@ -1440,15 +1440,15 @@ public:
|
|||
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
|
||||
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||
bool halfToFloat;
|
||||
if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected a boolean value for half_to_float");
|
||||
|
||||
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
|
||||
if (!resultTensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected result type to have a dtype");
|
||||
|
@ -1500,8 +1500,8 @@ public:
|
|||
Value output = op.getOutput();
|
||||
Value dim = op.getDim();
|
||||
|
||||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
|
||||
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||
|
||||
Value newGrad =
|
||||
|
@ -1536,8 +1536,8 @@ public:
|
|||
// Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
|
||||
Value output = op.getOutput();
|
||||
|
||||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
|
||||
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||
|
||||
Value tanhSquare =
|
||||
|
@ -1567,8 +1567,8 @@ public:
|
|||
Value output = op.getOutput();
|
||||
Value dim = op.getDim();
|
||||
|
||||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
|
||||
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||
|
||||
Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
|
||||
|
@ -1650,8 +1650,8 @@ public:
|
|||
Value keepDim = op.getKeepdim();
|
||||
Value result = op.getResult();
|
||||
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
BaseTensorType indicesTensorType = cast<BaseTensorType>(result.getType());
|
||||
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
||||
if (!maybeInputRank) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1670,7 +1670,7 @@ public:
|
|||
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
|
||||
// first the input tensor is flattened to 1d tensor and then the reduction
|
||||
// happens on the 0th dimension.
|
||||
if (dim.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(dim.getType())) {
|
||||
BaseTensorType flattenType =
|
||||
inputType
|
||||
.getWithSizesAndDtype({kUnknownSize},
|
||||
|
@ -1720,7 +1720,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
|
||||
Value input = op.getSelf();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: input must have known sizes");
|
||||
|
@ -1728,7 +1728,7 @@ public:
|
|||
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
||||
|
||||
Value boundaries = op.getBoundaries();
|
||||
auto boundariesType = boundaries.getType().cast<BaseTensorType>();
|
||||
auto boundariesType = cast<BaseTensorType>(boundaries.getType());
|
||||
if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unimplemented: boundaries must have "
|
||||
|
@ -1827,7 +1827,7 @@ static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) {
|
|||
Location loc = op.getLoc();
|
||||
Value dim = op.getDim();
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
|
||||
Value xMax =
|
||||
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
||||
if (!xMax)
|
||||
|
@ -1856,12 +1856,12 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.getSelf();
|
||||
if (!op.getDtype().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getDtype().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented non-None dtype for log_softmax");
|
||||
|
||||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
|
||||
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||
|
||||
Value logSoftmax = getLogSoftmaxResult(op, rewriter);
|
||||
|
@ -1974,7 +1974,7 @@ public:
|
|||
Type opType = op.getType();
|
||||
Value dim = op.getDim();
|
||||
|
||||
auto resType = self.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(self.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -2088,7 +2088,7 @@ public:
|
|||
|
||||
Location loc = op.getLoc();
|
||||
Value inValue = op.getSelf();
|
||||
auto inType = inValue.getType().cast<BaseTensorType>();
|
||||
auto inType = cast<BaseTensorType>(inValue.getType());
|
||||
auto maybeSizes = inType.getOptionalSizes();
|
||||
if (!maybeSizes) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -2234,7 +2234,7 @@ public:
|
|||
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
|
||||
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
||||
Value input) {
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
|
||||
Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
|
||||
Value cst6 =
|
||||
|
@ -2252,7 +2252,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenRelu6Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -2304,7 +2304,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
Value negativeSlope = op.getNegativeSlope();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -2341,7 +2341,7 @@ public:
|
|||
Value gradOutput = op.getGradOutput();
|
||||
Value input = op.getSelf();
|
||||
Value negativeSlope = op.getNegativeSlope();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -2382,7 +2382,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
Value weight = op.getWeight();
|
||||
auto resType = op.getType().cast<ValueTensorType>();
|
||||
auto resType = cast<ValueTensorType>(op.getType());
|
||||
auto boolTensorType = rewriter.getType<ValueTensorType>(
|
||||
resType.getOptionalSizes(), rewriter.getI1Type());
|
||||
Value zero =
|
||||
|
@ -2408,14 +2408,14 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenLerpScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
Value cstOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
auto start = op.getSelf();
|
||||
auto inputType = start.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(start.getType());
|
||||
|
||||
auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
|
||||
start, cstOne);
|
||||
|
@ -2442,7 +2442,7 @@ public:
|
|||
Value alpha = op.getAlpha();
|
||||
Value scale = op.getScale();
|
||||
Value inputScale = op.getInputScale();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -2486,7 +2486,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -2578,7 +2578,7 @@ public:
|
|||
}
|
||||
// Ensure all tensors have known sizes
|
||||
for (Value tensor : tensors) {
|
||||
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
|
||||
if (!tensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: one tensor does not have known sizes");
|
||||
|
@ -2596,7 +2596,8 @@ public:
|
|||
}
|
||||
|
||||
Type listElemType =
|
||||
op.getType().cast<BaseTensorType>().getWithSizesAndDtype(
|
||||
cast<BaseTensorType>(op.getType())
|
||||
.getWithSizesAndDtype(
|
||||
/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
|
||||
Type listType = Torch::ListType::get(listElemType);
|
||||
Value unsqueezedTensorList = rewriter.create<PrimListConstructOp>(
|
||||
|
@ -2635,7 +2636,7 @@ public:
|
|||
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
auto self = op.getSelf();
|
||||
auto selfTy = self.getType().cast<BaseTensorType>();
|
||||
auto selfTy = cast<BaseTensorType>(self.getType());
|
||||
// roll(input, shift, dim) = cat({
|
||||
// slice(input, dim, -shift, none),
|
||||
// slice(input, dim, 0, -shift)}, dim)
|
||||
|
@ -2817,7 +2818,7 @@ public:
|
|||
if (!selfTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: no implementation for rankless tensor");
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: no implementation for rankless tensor");
|
||||
|
@ -2968,7 +2969,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
MLIRContext *context = op.getContext();
|
||||
BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType outputTensorType = cast<BaseTensorType>(op.getType());
|
||||
if (!outputTensorType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: output must have known sizes");
|
||||
|
@ -2977,7 +2978,7 @@ public:
|
|||
if (!maybeRank)
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
||||
unsigned inputRank = *maybeRank;
|
||||
auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
|
||||
auto inputTensorType = cast<Torch::ValueTensorType>(self.getType());
|
||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Expected input type having sizes");
|
||||
|
@ -3077,7 +3078,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenWhereScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -3100,7 +3101,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -3122,7 +3123,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -3186,7 +3187,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -3227,7 +3228,7 @@ static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter,
|
|||
int64_t dimB,
|
||||
Value &transposed) {
|
||||
Type transposedType;
|
||||
if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
|
||||
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
|
||||
dimA, dimB, transposedType)))
|
||||
return failure();
|
||||
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
@ -3578,7 +3579,7 @@ public:
|
|||
op.getGroups(), op.getDilation());
|
||||
|
||||
Type transposedType;
|
||||
if (failed(getTransposedType(input.getType().cast<BaseTensorType>(), 0, 1,
|
||||
if (failed(getTransposedType(cast<BaseTensorType>(input.getType()), 0, 1,
|
||||
transposedType)))
|
||||
return failure();
|
||||
Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||
|
@ -3605,7 +3606,7 @@ public:
|
|||
ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2],
|
||||
gradOutputSize[3]});
|
||||
|
||||
BaseTensorType gradOutputTy = gradOutput.getType().cast<BaseTensorType>();
|
||||
BaseTensorType gradOutputTy = cast<BaseTensorType>(gradOutput.getType());
|
||||
if (!gradOutputTy.hasSizes())
|
||||
return failure();
|
||||
SmallVector<int64_t> gradOutputSizesInt(gradOutputTy.getSizes());
|
||||
|
@ -3625,7 +3626,7 @@ public:
|
|||
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
|
||||
|
||||
BaseTensorType inputTransposedTy =
|
||||
inputTransposed.getType().cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(inputTransposed.getType());
|
||||
if (!inputTransposedTy.hasSizes())
|
||||
return failure();
|
||||
SmallVector<int64_t> inputTransposedSizesInt(
|
||||
|
@ -3660,7 +3661,7 @@ public:
|
|||
/*dilation=*/op.getStride(), op.getTransposed(),
|
||||
op.getOutputPadding(), numGroup);
|
||||
|
||||
BaseTensorType weightTy = weight.getType().cast<BaseTensorType>();
|
||||
BaseTensorType weightTy = cast<BaseTensorType>(weight.getType());
|
||||
if (!weightTy.hasSizes())
|
||||
return failure();
|
||||
SmallVector<int64_t> weightSizes(weightTy.getSizes());
|
||||
|
@ -3707,7 +3708,7 @@ public:
|
|||
gradWeight = rewriter.create<Torch::AtenViewOp>(
|
||||
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
|
||||
|
||||
gradWeightTy = gradWeight.getType().cast<BaseTensorType>();
|
||||
gradWeightTy = cast<BaseTensorType>(gradWeight.getType());
|
||||
SmallVector<int64_t, 5> gradWeightDimsOrder =
|
||||
computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size());
|
||||
SmallVector<int64_t, 5> gradWeightMoveDimShape;
|
||||
|
@ -3733,7 +3734,7 @@ public:
|
|||
/*keepdim=*/cstFalse,
|
||||
/*dtype=*/cstNone);
|
||||
} else {
|
||||
if (failed(getTransposedType(gradOutput.getType().cast<BaseTensorType>(),
|
||||
if (failed(getTransposedType(cast<BaseTensorType>(gradOutput.getType()),
|
||||
0, 1, transposedType)))
|
||||
return failure();
|
||||
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||
|
@ -3792,7 +3793,7 @@ public:
|
|||
}
|
||||
|
||||
// TODO: Handle integer type operands.
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non-floating point dtype");
|
||||
|
@ -3821,7 +3822,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
Value output = op.getResult();
|
||||
BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
|
||||
BaseTensorType outputTensorType = cast<BaseTensorType>(output.getType());
|
||||
Value sum =
|
||||
rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.getDtype());
|
||||
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
|
||||
|
@ -3854,7 +3855,7 @@ public:
|
|||
Type outputType = op.getType();
|
||||
MLIRContext *context = op.getContext();
|
||||
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
|
||||
!isNoneOrFloatDtype(context, dtype)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -3944,7 +3945,7 @@ public:
|
|||
rewriter.replaceOp(op, input);
|
||||
return success();
|
||||
}
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support floating type input for training mode");
|
||||
|
@ -3992,7 +3993,7 @@ public:
|
|||
rewriter.replaceOp(op, ArrayRef<Value>{input, trueMask});
|
||||
return success();
|
||||
}
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support floating type input for training mode");
|
||||
|
@ -4029,7 +4030,7 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
||||
}
|
||||
unsigned inputRank = *maybeInputRank;
|
||||
BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType rank0FloatTensorTy = cast<BaseTensorType>(op.getType());
|
||||
if (!rank0FloatTensorTy.hasSizes() ||
|
||||
rank0FloatTensorTy.getSizes().size() != 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4060,7 +4061,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenStdOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
|
||||
if (!inputTensorTy.hasDtype() ||
|
||||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -4084,7 +4085,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
|
||||
Value inputTimesBeta =
|
||||
rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.getBeta());
|
||||
|
@ -4116,7 +4117,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenStdDimOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
||||
if (!inputTensorType.hasDtype() ||
|
||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4141,7 +4142,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenStdCorrectionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
||||
if (!inputTensorType.hasDtype() ||
|
||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4167,8 +4168,8 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -4208,8 +4209,8 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
auto resType = cast<BaseTensorType>(op.getType());
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
@ -4235,7 +4236,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
Type resultType = op.getType();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support floating-point type");
|
||||
|
@ -4268,8 +4269,8 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
|
|||
Operation *op, Location loc,
|
||||
Value input, Value prob,
|
||||
Value &output) {
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto probType = prob.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
auto probType = cast<BaseTensorType>(prob.getType());
|
||||
// Both the `input` and `prob` must be ranked tensors.
|
||||
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
|
||||
!probType.hasDtype()) {
|
||||
|
@ -4338,12 +4339,12 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
Value p = op.getP();
|
||||
if (!op.getGenerator().getType().template isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
SmallVector<int64_t> empty;
|
||||
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
|
||||
rewriter.getF64Type());
|
||||
|
@ -4485,7 +4486,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
|
||||
auto input = op.getInput().getType().cast<BaseTensorType>();
|
||||
auto input = cast<BaseTensorType>(op.getInput().getType());
|
||||
if (!input.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor should have known sizes.");
|
||||
|
@ -4518,7 +4519,7 @@ class DecomposeAtenInstanceNormOp
|
|||
Location loc = op.getLoc();
|
||||
auto context = op.getContext();
|
||||
|
||||
auto inputTy = op.getInput().getType().cast<BaseTensorType>();
|
||||
auto inputTy = cast<BaseTensorType>(op.getInput().getType());
|
||||
int64_t inputRank = inputTy.getSizes().size();
|
||||
SmallVector<int64_t> reducedShape(inputTy.getSizes());
|
||||
SmallVector<int64_t> reduceDimInts;
|
||||
|
@ -4583,7 +4584,7 @@ class DecomposeAtenInstanceNormOp
|
|||
loc, op.getResult().getType(), inputNormalized);
|
||||
|
||||
Value weight = op.getWeight();
|
||||
auto weightTy = weight.getType().cast<BaseTensorType>();
|
||||
auto weightTy = cast<BaseTensorType>(weight.getType());
|
||||
dtype = weightTy.getOptionalDtype();
|
||||
|
||||
SmallVector<int64_t> weightShape(weightTy.getSizes());
|
||||
|
@ -4610,7 +4611,7 @@ class DecomposeAtenInstanceNormOp
|
|||
rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
|
||||
|
||||
Value bias = op.getBias();
|
||||
auto biasTy = bias.getType().cast<BaseTensorType>();
|
||||
auto biasTy = cast<BaseTensorType>(bias.getType());
|
||||
dtype = biasTy.getOptionalDtype();
|
||||
|
||||
SmallVector<int64_t> biasShape(biasTy.getSizes());
|
||||
|
@ -4654,7 +4655,7 @@ class DecomposeAtenNativeLayerNormOp
|
|||
Location loc = op.getLoc();
|
||||
auto context = op.getContext();
|
||||
|
||||
auto inputTy = op.getInput().getType().cast<BaseTensorType>();
|
||||
auto inputTy = cast<BaseTensorType>(op.getInput().getType());
|
||||
if (!inputTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor should have known sizes.");
|
||||
|
@ -4889,10 +4890,10 @@ class DecomposeAtenNativeGroupNormOp
|
|||
Value eps = op.getEps();
|
||||
|
||||
// Check the rank of the input/outputs tensor.
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto outputType = op.getResult0().getType().cast<BaseTensorType>();
|
||||
auto meanType = op.getResult1().getType().cast<BaseTensorType>();
|
||||
auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
auto outputType = cast<BaseTensorType>(op.getResult0().getType());
|
||||
auto meanType = cast<BaseTensorType>(op.getResult1().getType());
|
||||
auto rsqrtVarType = cast<BaseTensorType>(op.getResult2().getType());
|
||||
if (!inputType.hasSizes() || !outputType.hasSizes() ||
|
||||
!meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -5059,8 +5060,8 @@ class DecomposeAtenNativeBatchNormOp
|
|||
|
||||
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
|
||||
runningStatsShapeInt[1] =
|
||||
runningMean.getType().cast<BaseTensorType>().getSizes()[0];
|
||||
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
|
||||
cast<BaseTensorType>(runningMean.getType()).getSizes()[0];
|
||||
Type dtype = cast<ValueTensorType>(input.getType()).getOptionalDtype();
|
||||
Type reshapeType = ValueTensorType::get(
|
||||
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
|
||||
|
||||
|
@ -5175,8 +5176,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
BaseTensorType tensorType =
|
||||
op.getSelf().getType().template cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||
if (!tensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected input tensor to have a dtype");
|
||||
|
@ -5200,7 +5200,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenFullOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
|
||||
BaseTensorType outTy = cast<BaseTensorType>(op.getType());
|
||||
if (!outTy.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected result type to have a dtype");
|
||||
|
@ -5231,12 +5231,12 @@ public:
|
|||
Value weight = op.getWeight();
|
||||
Value bias = op.getBias();
|
||||
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected input to be rank 2 or greater");
|
||||
|
||||
BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
|
||||
BaseTensorType weightType = cast<BaseTensorType>(weight.getType());
|
||||
// `weight` must be a rank 2 matrix.
|
||||
if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
|
||||
return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");
|
||||
|
@ -5255,7 +5255,7 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
|
||||
BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
|
||||
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
|
||||
|
||||
|
@ -5280,7 +5280,7 @@ public:
|
|||
Value input = op.getSelf();
|
||||
Type type = op.getType();
|
||||
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype())
|
||||
return rewriter.notifyMatchFailure(op, "Dtype not present");
|
||||
|
||||
|
@ -5306,7 +5306,7 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenFullLikeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
|
||||
BaseTensorType outTy = cast<BaseTensorType>(op.getType());
|
||||
if (!outTy.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected result type to have a dtype");
|
||||
|
@ -5335,7 +5335,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||
if (!tensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected input tensor to have a dtype");
|
||||
|
@ -5393,7 +5393,7 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_ToCopyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto resultType = op.getType().cast<BaseTensorType>();
|
||||
auto resultType = cast<BaseTensorType>(op.getType());
|
||||
if (!resultType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected result type to have a dtype");
|
||||
|
@ -5419,12 +5419,12 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenCopyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto resultType = op.getType().cast<BaseTensorType>();
|
||||
auto resultType = cast<BaseTensorType>(op.getType());
|
||||
if (!resultType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected result type to have a dtype");
|
||||
}
|
||||
auto srcTy = op.getSrc().getType().cast<BaseTensorType>();
|
||||
auto srcTy = cast<BaseTensorType>(op.getSrc().getType());
|
||||
if (!srcTy.hasSizes() || !srcTy.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected src type to have a known rank and dtype");
|
||||
|
@ -5448,7 +5448,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
|||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||
if (!tensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected input tensor to have a dtype");
|
||||
|
@ -5588,7 +5588,7 @@ public:
|
|||
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().template isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(dtype.getType())) {
|
||||
dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
|
||||
|
@ -5665,7 +5665,7 @@ class DecomposeAtenAdaptiveAvgPool1dOp
|
|||
|
||||
SmallVector<Value, 1> kernelSize;
|
||||
if (outputSizeInt == 1) {
|
||||
BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
|
||||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||
kernelSize.push_back(
|
||||
inputShape[rank - 1] == kUnknownSize
|
||||
|
@ -5839,7 +5839,7 @@ class DecomposeAtenCosineSimilarityOp
|
|||
SmallVector<Value> indexBroadcastShapeValue;
|
||||
computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
|
||||
indexBroadcastShapeValue);
|
||||
Type dtype = x1.getType().cast<BaseTensorType>().getOptionalDtype();
|
||||
Type dtype = cast<BaseTensorType>(x1.getType()).getOptionalDtype();
|
||||
Type broadcastType = ValueTensorType::get(
|
||||
op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype);
|
||||
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
||||
|
@ -5925,9 +5925,9 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
|
|||
Value alphaTimesBmm =
|
||||
rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
|
||||
Value input = op.getSelf();
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
BaseTensorType resultType =
|
||||
op->getResult(0).getType().cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(op->getResult(0).getType());
|
||||
if (inputType.hasDtype() && resultType.hasDtype() &&
|
||||
inputType.getDtype() != resultType.getDtype()) {
|
||||
input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype());
|
||||
|
@ -6011,7 +6011,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
Value self = op.getSelf();
|
||||
Value dimList = op.getDim();
|
||||
Value keepDim = op.getKeepdim();
|
||||
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
|
||||
Type outputType = op.getType();
|
||||
BaseTensorType outputTensorType = cast<BaseTensorType>(outputType);
|
||||
if (!outputTensorType.hasDtype()) {
|
||||
|
@ -6030,7 +6030,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
// computation of the result.
|
||||
if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
|
||||
self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
|
||||
inputTensorTy = self.getType().cast<BaseTensorType>();
|
||||
inputTensorTy = cast<BaseTensorType>(self.getType());
|
||||
}
|
||||
|
||||
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
||||
|
@ -6040,7 +6040,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
unsigned inputRank = *maybeInputRank;
|
||||
SmallVector<Value> dimListElements;
|
||||
bool isNoneOrEmpty = true;
|
||||
if (!dimList.getType().template isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(dimList.getType())) {
|
||||
if (!getListConstructElements(dimList, dimListElements))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expect dimList to be constructed from list construct");
|
||||
|
@ -6287,8 +6287,8 @@ public:
|
|||
op, "Expected a constant integer value for reduction");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
BaseTensorType resultType = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = op.getSelf().getType().cast<BaseTensorType>();
|
||||
BaseTensorType resultType = cast<BaseTensorType>(op.getType());
|
||||
BaseTensorType inputType = cast<BaseTensorType>(op.getSelf().getType());
|
||||
if (!inputType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected the input tensor to have sizes");
|
||||
|
@ -6506,7 +6506,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resultType = op.getType().cast<BaseTensorType>();
|
||||
auto resultType = cast<BaseTensorType>(op.getType());
|
||||
|
||||
if (!resultType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -6617,7 +6617,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto resultType = op.getType().cast<BaseTensorType>();
|
||||
auto resultType = cast<BaseTensorType>(op.getType());
|
||||
|
||||
if (!resultType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -6943,7 +6943,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|||
auto context = op.getContext();
|
||||
|
||||
Value input = op.getSelf();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor should have known sizes.");
|
||||
|
@ -6974,7 +6974,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|||
|
||||
// compare
|
||||
auto eqType = ValueTensorType::get(
|
||||
context, op.getType().cast<BaseTensorType>().getSizes(),
|
||||
context, cast<BaseTensorType>(op.getType()).getSizes(),
|
||||
IntegerType::get(context, 1));
|
||||
Value eqTensor = rewriter.create<AtenEqTensorOp>(
|
||||
loc, eqType, unsqueezeTensor, arangeTensor);
|
||||
|
@ -7019,7 +7019,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenScalarTensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
auto resultTy = op.getResult().getType().cast<BaseTensorType>();
|
||||
auto resultTy = cast<BaseTensorType>(op.getResult().getType());
|
||||
auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType());
|
||||
Value numToTensor = rewriter.create<PrimNumToTensorScalarOp>(
|
||||
op.getLoc(),
|
||||
|
@ -7060,7 +7060,7 @@ public:
|
|||
|
||||
Value self = op.getSelf();
|
||||
Value dim = op.getDim();
|
||||
auto selfType = self.getType().cast<BaseTensorType>();
|
||||
auto selfType = cast<BaseTensorType>(self.getType());
|
||||
auto sortIndicesType = selfType.getWithSizesAndDtype(
|
||||
selfType.getOptionalSizes(),
|
||||
IntegerType::get(context, 64, IntegerType::Signed));
|
||||
|
@ -7111,8 +7111,8 @@ public:
|
|||
Value sizeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, ListType::get(IntType::get(context)), sizes);
|
||||
|
||||
auto selfType = self.getType().cast<BaseTensorType>();
|
||||
auto indexType = index.getType().cast<BaseTensorType>();
|
||||
auto selfType = cast<BaseTensorType>(self.getType());
|
||||
auto indexType = cast<BaseTensorType>(index.getType());
|
||||
BaseTensorType srcType =
|
||||
selfType
|
||||
.getWithSizesAndDtype(indexType.getOptionalSizes(),
|
||||
|
@ -7135,7 +7135,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenSgnOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto outType = op.getType().cast<BaseTensorType>();
|
||||
auto outType = cast<BaseTensorType>(op.getType());
|
||||
if (!outType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"expected result type to have dtype");
|
||||
|
@ -7273,14 +7273,14 @@ public:
|
|||
"failed to get elements of `indices`");
|
||||
|
||||
auto input = op.getSelf();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only input with shape information is supported");
|
||||
}
|
||||
auto inputSizes = inputType.getSizes();
|
||||
int64_t inputRank = inputSizes.size();
|
||||
auto outputType = op.getType().cast<BaseTensorType>();
|
||||
auto outputType = cast<BaseTensorType>(op.getType());
|
||||
if (!outputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only output with shape information is supported");
|
||||
|
@ -7438,7 +7438,7 @@ public:
|
|||
op, "failed to get elements of `dims` param");
|
||||
}
|
||||
auto dimsSize = dimsElements.size();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support input tensor with shape information");
|
||||
|
|
|
@ -89,7 +89,7 @@ public:
|
|||
.cast<ValueTensorType>()
|
||||
.getOptionalDtype();
|
||||
auto torchQType =
|
||||
quant.getType().cast<ValueTensorType>().getOptionalDtype();
|
||||
cast<ValueTensorType>(quant.getType()).getOptionalDtype();
|
||||
auto transQTy =
|
||||
rewriter.getType<ValueTensorType>(trans.getResult()
|
||||
.getType()
|
||||
|
@ -152,7 +152,7 @@ public:
|
|||
return failure();
|
||||
|
||||
Value bias = operands[2];
|
||||
auto biasTy = bias.getType().dyn_cast<ValueTensorType>();
|
||||
auto biasTy = dyn_cast<ValueTensorType>(bias.getType());
|
||||
|
||||
if (biasTy) {
|
||||
auto biasETy = biasTy.getOptionalDtype();
|
||||
|
|
|
@ -134,7 +134,7 @@ private:
|
|||
slotName = setAttrOp.getName();
|
||||
}
|
||||
|
||||
auto moduleType = module.getType().cast<NnModuleType>();
|
||||
auto moduleType = cast<NnModuleType>(module.getType());
|
||||
auto slots = moduleClassNameToSlots.find(moduleType.getClassName());
|
||||
// TODO: Improve verifier so that this can never happen
|
||||
if (slots == moduleClassNameToSlots.end())
|
||||
|
@ -163,13 +163,13 @@ private:
|
|||
}
|
||||
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||
nnModule.getType().cast<NnModuleType>().getClassName());
|
||||
cast<NnModuleType>(nnModule.getType()).getClassName());
|
||||
for (auto t :
|
||||
llvm::zip(nnModule.getOps<SlotOp>(), classType.getOps<AttrOp>())) {
|
||||
auto slot = std::get<0>(t);
|
||||
auto attr = std::get<1>(t);
|
||||
nameStack.push_back(attr.getName().str());
|
||||
if (attr.getType().isa<NnModuleType>()) {
|
||||
if (isa<NnModuleType>(attr.getType())) {
|
||||
if (failed(recursivelyTraverse(
|
||||
slot.getValue().getDefiningOp<NnModuleOp>())))
|
||||
return failure();
|
||||
|
@ -333,7 +333,7 @@ static LogicalResult analyzeInstances(func::FuncOp func,
|
|||
for (auto &argInstance : argInstances)
|
||||
mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
|
||||
auto walkResult = func.walk([&](PrimGetAttrOp op) {
|
||||
if (!op.getType().isa<NnModuleType>())
|
||||
if (!isa<NnModuleType>(op.getType()))
|
||||
return WalkResult::advance();
|
||||
auto instance = mapping.lookupOrNull(op.getReceiver());
|
||||
assert(instance && "verifyFuncConformsToSubset should ensure this");
|
||||
|
@ -355,7 +355,7 @@ createMonomorphizationForCall(func::CallOp op, IRMapping &mapping,
|
|||
Monomorphization monomorphization;
|
||||
monomorphization.func = func;
|
||||
for (auto operand : llvm::enumerate(op->getOperands())) {
|
||||
if (!operand.value().getType().isa<NnModuleType>())
|
||||
if (!isa<NnModuleType>(operand.value().getType()))
|
||||
continue;
|
||||
Value instance = mapping.lookupOrNull(operand.value());
|
||||
assert(instance && "verifyFuncConformsToSubset should ensure this");
|
||||
|
@ -377,7 +377,7 @@ public:
|
|||
monomorphization.func = func;
|
||||
bool canTriviallyMonomorphize = true;
|
||||
for (auto arg : llvm::enumerate(func.getArguments())) {
|
||||
auto type = arg.value().getType().dyn_cast<NnModuleType>();
|
||||
auto type = dyn_cast<NnModuleType>(arg.value().getType());
|
||||
if (!type)
|
||||
continue;
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(type.getClassName());
|
||||
|
@ -436,7 +436,7 @@ private:
|
|||
// !torch.nn.Module<"..."> types.
|
||||
static LogicalResult verifyNnModuleValueUses(Value value) {
|
||||
// Trivially succeed for non-module types.
|
||||
if (!value.getType().isa<NnModuleType>())
|
||||
if (!isa<NnModuleType>(value.getType()))
|
||||
return success();
|
||||
for (Operation *op : value.getUsers()) {
|
||||
if (isa<func::CallOp, PrimGetAttrOp>(op))
|
||||
|
@ -516,7 +516,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
|
|||
return WalkResult::advance();
|
||||
};
|
||||
auto handlePrimGetAttr = [&](PrimGetAttrOp op) {
|
||||
if (!op.getType().isa<NnModuleType>()) {
|
||||
if (!isa<NnModuleType>(op.getType())) {
|
||||
auto instance =
|
||||
mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
|
||||
SlotOp affectedSlot;
|
||||
|
@ -540,7 +540,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
|
|||
Monomorphization monomorphization = std::move(*maybeMonomorphization);
|
||||
auto newArguments = llvm::to_vector<6>(
|
||||
llvm::make_filter_range(op->getOperands(), [](Value v) {
|
||||
return !v.getType().isa<NnModuleType>();
|
||||
return !isa<NnModuleType>(v.getType());
|
||||
}));
|
||||
assert(newFuncs.find(monomorphization) != newFuncs.end());
|
||||
auto newOp = OpBuilder(op).create<func::CallOp>(
|
||||
|
@ -564,7 +564,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
|
|||
}
|
||||
llvm::BitVector argsToErase(func.getNumArguments());
|
||||
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
|
||||
if (type.value().isa<NnModuleType>()) {
|
||||
if (isa<NnModuleType>(type.value())) {
|
||||
argsToErase.set(type.index());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -248,8 +248,8 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
|
|||
}))
|
||||
continue;
|
||||
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
|
||||
auto symName = initialize.getSlotSymNames()[use.getOperandNumber()]
|
||||
.cast<FlatSymbolRefAttr>();
|
||||
auto symName = cast<FlatSymbolRefAttr>(
|
||||
initialize.getSlotSymNames()[use.getOperandNumber()]);
|
||||
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
||||
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
|
||||
if (state->isSafe)
|
||||
|
@ -333,10 +333,10 @@ class InlineGlobalSlotsPass
|
|||
DenseSet</*FlatSymbolRefAttr*/ Attribute> safeToInline;
|
||||
for (int i = 0, e = initialize->getNumOperands(); i != e; i++) {
|
||||
auto slotSymName =
|
||||
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
|
||||
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
|
||||
Value operand = initialize.getOperand(i);
|
||||
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>());
|
||||
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]));
|
||||
auto *state =
|
||||
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
|
||||
// We roll the analysis of whether a slot is set or public into the
|
||||
|
@ -408,7 +408,7 @@ class InlineGlobalSlotsPass
|
|||
SmallVector<Value> newInitialValues;
|
||||
for (int i = 0, e = initialize.getNumOperands(); i != e; i++) {
|
||||
auto slotSymName =
|
||||
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
|
||||
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
|
||||
if (!safeToInline.count(slotSymName)) {
|
||||
newSlotSymNames.push_back(slotSymName);
|
||||
newInitialValues.push_back(initialize.getOperand(i));
|
||||
|
|
|
@ -118,7 +118,7 @@ static LogicalResult checkType(Operation *op, Type type,
|
|||
if (auto optionalType = dyn_cast<OptionalType>(type)) {
|
||||
// TODO: Be stricter about tensor types.
|
||||
// See comment below for ListType.
|
||||
if (optionalType.getContainedType().isa<ValueTensorType>())
|
||||
if (isa<ValueTensorType>(optionalType.getContainedType()))
|
||||
return success();
|
||||
return checkType(op, optionalType.getContainedType(),
|
||||
actuallyEmitDiagnostics);
|
||||
|
@ -134,7 +134,7 @@ static LogicalResult checkType(Operation *op, Type type,
|
|||
// the contained type information. Somehow this slips through and works.
|
||||
// We should be stricter about this and properly infer the contained type
|
||||
// and shape.
|
||||
if (listType.getContainedType().isa<ValueTensorType>())
|
||||
if (isa<ValueTensorType>(listType.getContainedType()))
|
||||
return success();
|
||||
return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics);
|
||||
}
|
||||
|
@ -535,7 +535,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
}
|
||||
target.addDynamicallyLegalOp<OperatorOp>(
|
||||
[backendLegalOpsSet](OperatorOp opOp) {
|
||||
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
|
||||
auto opName = cast<StringAttr>(opOp->getAttr("name")).getValue();
|
||||
return backendLegalOpsSet.contains(opName);
|
||||
});
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ public:
|
|||
op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
|
||||
op.getOperand(3), op.getOperand(4));
|
||||
|
||||
auto clampTy = clamp.getType().cast<Torch::ValueTensorType>();
|
||||
auto clampTy = cast<Torch::ValueTensorType>(clamp.getType());
|
||||
if (!clampTy.hasDtype())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"dequantization has unknown dtype");
|
||||
|
|
|
@ -23,7 +23,7 @@ using namespace mlir::torch;
|
|||
using namespace mlir::torch::Torch;
|
||||
|
||||
static Value assertNonValueTensor(Value tensor) {
|
||||
assert(tensor.getType().isa<NonValueTensorType>() &&
|
||||
assert(isa<NonValueTensorType>(tensor.getType()) &&
|
||||
"tensor is expected to be a non-value tensor");
|
||||
return tensor;
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ public:
|
|||
// to use value semantics (which happens for example with ops
|
||||
// that take two aliases as input), then it is possible that the
|
||||
// op no longer generates an alias.
|
||||
if (userResult.getType().isa<NonValueTensorType>())
|
||||
if (isa<NonValueTensorType>(userResult.getType()))
|
||||
availableAliases.insert(userResult);
|
||||
result.viewLikeOps.push_back(user);
|
||||
} else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
||||
|
@ -177,7 +177,7 @@ public:
|
|||
for (Operation *viewLikeOp : ops.viewLikeOps) {
|
||||
rewriter.modifyOpInPlace(viewLikeOp, [&] {
|
||||
Value result = viewLikeOp->getResult(0);
|
||||
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
|
||||
auto resultType = dyn_cast<NonValueTensorType>(result.getType());
|
||||
if (resultType)
|
||||
result.setType(resultType.getWithValueSemantics());
|
||||
});
|
||||
|
@ -230,7 +230,7 @@ public:
|
|||
if (isViewLikeOp(op)) {
|
||||
// We currently only support view-like ops with one tensor output.
|
||||
if (op->getNumResults() != 1 ||
|
||||
!op->getResult(0).getType().isa<BaseTensorType>()) {
|
||||
!isa<BaseTensorType>(op->getResult(0).getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
copy, "unsupported: view-like ops must have one tensor output, "
|
||||
"and the tensor output must be the first result");
|
||||
|
@ -242,7 +242,7 @@ public:
|
|||
// non-value tensor and the output being a value tensor. If this is the
|
||||
// case then there is no need to look at the users of the result of the
|
||||
// op.
|
||||
if (opResult.getType().isa<NonValueTensorType>()) {
|
||||
if (isa<NonValueTensorType>(opResult.getType())) {
|
||||
if (operand.getOperandNumber() == 0) {
|
||||
validViewLikeOps.insert(op);
|
||||
llvm::append_range(workList, opResult.getUses());
|
||||
|
@ -339,7 +339,7 @@ public:
|
|||
for (Operation *op : viewLikeOps) {
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
if (auto nonValueTensorType =
|
||||
op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
|
||||
dyn_cast<NonValueTensorType>(op->getResult(0).getType())) {
|
||||
originalTypes[op->getResult(0)] = nonValueTensorType;
|
||||
op->getResult(0).setType(nonValueTensorType.getWithValueSemantics());
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ public:
|
|||
LogicalResult matchAndRewrite(PrimCallMethodOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||
op.getReceiver().getType().cast<NnModuleType>().getClassName());
|
||||
cast<NnModuleType>(op.getReceiver().getType()).getClassName());
|
||||
assert(classType && "malformed module -- missing ClassTypeOp");
|
||||
func::FuncOp func;
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
|
@ -94,7 +94,7 @@ class PrepareForGlobalizeObjectGraphPass
|
|||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<PrimCallMethodOp>();
|
||||
target.addDynamicallyLegalOp<func::ConstantOp>(
|
||||
[](func::ConstantOp op) { return !op.getType().isa<FunctionType>(); });
|
||||
[](func::ConstantOp op) { return !isa<FunctionType>(op.getType()); });
|
||||
target.addIllegalOp<func::CallIndirectOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ public:
|
|||
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
|
||||
|
||||
// Create IndexPut_Op
|
||||
BaseTensorType tensorType = op.getType().cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getType());
|
||||
Type rangeType = tensorType.getWithSizesAndDtype(
|
||||
{kUnknownSize}, tensorType.getOptionalDtype());
|
||||
Value range = rewriter.create<AtenArangeStartStepOp>(
|
||||
|
@ -130,8 +130,7 @@ public:
|
|||
|
||||
// Create IndexPut_Op
|
||||
// Convert indexNum to indexTensor for the selectOp
|
||||
BaseTensorType selectOutTy =
|
||||
selectOp.getType().template cast<BaseTensorType>();
|
||||
BaseTensorType selectOutTy = cast<BaseTensorType>(selectOp.getType());
|
||||
SmallVector<int64_t> empty;
|
||||
auto dtype = getTypeForTorchType(selectOp.getContext(),
|
||||
selectOp.getIndex().getType());
|
||||
|
@ -141,7 +140,7 @@ public:
|
|||
selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
|
||||
|
||||
// Create indicesVector for IndexPut_Op by TorchNone and indexTensor
|
||||
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(op->getResultTypes()[0]);
|
||||
SmallVector<Value> indicesVector(dim, noneVal);
|
||||
indicesVector.push_back(indexTensor);
|
||||
|
||||
|
|
|
@ -26,8 +26,8 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter,
|
|||
Location loc, Value overwriterTensor,
|
||||
Value overwrittenTensor) {
|
||||
Type overwriterTensorType = overwriterTensor.getType();
|
||||
Type overwrittenTensorType = overwrittenTensor.getType()
|
||||
.dyn_cast<NonValueTensorType>()
|
||||
Type overwrittenTensorType =
|
||||
dyn_cast<NonValueTensorType>(overwrittenTensor.getType())
|
||||
.getWithValueSemantics();
|
||||
if (overwriterTensorType != overwrittenTensorType) {
|
||||
overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
|
||||
|
@ -58,7 +58,7 @@ operatorOpHasValueSemantics(OperatorOp opOp,
|
|||
std::optional<SymbolTable> extraLibrary) {
|
||||
if (!extraLibrary.has_value())
|
||||
return false;
|
||||
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
|
||||
auto opName = cast<StringAttr>(opOp->getAttr("name")).getValue();
|
||||
std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix(
|
||||
LibraryFunctionKind::HasValueSemantics) +
|
||||
Twine(opName))
|
||||
|
@ -96,8 +96,8 @@ public:
|
|||
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
||||
opOperand.get()));
|
||||
} else if (auto listType = dyn_cast<ListType>(operandType)) {
|
||||
if (!(listType.getContainedType().isa<NonValueTensorType>() ||
|
||||
listType.getContainedType().isa<OptionalType>()))
|
||||
if (!(isa<NonValueTensorType>(listType.getContainedType()) ||
|
||||
isa<OptionalType>(listType.getContainedType())))
|
||||
continue;
|
||||
|
||||
// Construct a new list whose elements are value tensors copied from
|
||||
|
@ -116,7 +116,7 @@ public:
|
|||
|
||||
// TODO: Handle optional type in list type.
|
||||
if (auto optionalType =
|
||||
listType.getContainedType().dyn_cast<OptionalType>()) {
|
||||
dyn_cast<OptionalType>(listType.getContainedType())) {
|
||||
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
|
||||
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
|
||||
})) {
|
||||
|
@ -129,7 +129,7 @@ public:
|
|||
|
||||
auto newListElements = llvm::to_vector(llvm::map_range(
|
||||
listConstruct.getElements(), [&](Value tensor) -> Value {
|
||||
if (tensor.getType().isa<NonValueTensorType>()) {
|
||||
if (isa<NonValueTensorType>(tensor.getType())) {
|
||||
return rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
||||
tensor);
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ public:
|
|||
} else if (auto optionalType = dyn_cast<OptionalType>(operandType)) {
|
||||
// TODO: A more general way to handle the optional type is to
|
||||
// introduce a `copy.to_optional_vtensor` op.
|
||||
if (!optionalType.getContainedType().isa<NonValueTensorType>())
|
||||
if (!isa<NonValueTensorType>(optionalType.getContainedType()))
|
||||
continue;
|
||||
|
||||
// Create a new optional value whose input is a value tensor copied
|
||||
|
@ -160,7 +160,7 @@ public:
|
|||
"derefine");
|
||||
}
|
||||
|
||||
if (!derefine.getOperand().getType().isa<NonValueTensorType>())
|
||||
if (!isa<NonValueTensorType>(derefine.getOperand().getType()))
|
||||
continue;
|
||||
auto newOperand = rewriter.create<CopyToValueTensorOp>(
|
||||
op->getLoc(), derefine.getOperand());
|
||||
|
@ -172,7 +172,7 @@ public:
|
|||
// Convert all results.
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
for (Value result : op->getResults()) {
|
||||
auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
|
||||
auto tensorType = dyn_cast<NonValueTensorType>(result.getType());
|
||||
if (!tensorType)
|
||||
continue;
|
||||
result.setType(tensorType.getWithValueSemantics());
|
||||
|
|
|
@ -84,7 +84,7 @@ class RefinePublicReturnPass
|
|||
}
|
||||
}
|
||||
|
||||
if (auto tensorType = newOperand.getType().dyn_cast<BaseTensorType>()) {
|
||||
if (auto tensorType = dyn_cast<BaseTensorType>(newOperand.getType())) {
|
||||
newOperands.push_back(
|
||||
copyTensorToType(builder, returnOp->getLoc(),
|
||||
tensorType.getWithValueSemantics(), newOperand));
|
||||
|
|
|
@ -118,7 +118,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
|
|||
assert(call.getNumResults() == 1 &&
|
||||
"Multiple results are packed in a tuple in Python!");
|
||||
Value result = call.getResult(0);
|
||||
if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
|
||||
if (auto tupleType = dyn_cast<Torch::TupleType>(result.getType())) {
|
||||
auto unpack = b.create<PrimTupleUnpackOp>(
|
||||
loc, tupleType.getContainedTypes(), result);
|
||||
llvm::append_range(unpackedResults, unpack.getResults());
|
||||
|
@ -275,7 +275,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
|||
// for i in range(len(operand)):
|
||||
// adjusted_list.append(adjust(operand[i]))
|
||||
// return adjusted_list
|
||||
auto providedType = operand.getType().cast<Torch::ListType>();
|
||||
auto providedType = cast<Torch::ListType>(operand.getType());
|
||||
Value adjustedList =
|
||||
b.create<PrimListConstructOp>(loc, desiredListType, ValueRange({}));
|
||||
// Create a for-like PrimLoopOp.
|
||||
|
@ -312,7 +312,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
|||
// signature uses `Scalar` (see comments in torch_ods_gen.py for
|
||||
// explanation).
|
||||
if (isa<Torch::FloatType>(desiredType) &&
|
||||
operand.getType().isa<Torch::IntType>()) {
|
||||
isa<Torch::IntType>(operand.getType())) {
|
||||
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
|||
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
||||
Type desiredType) -> Value {
|
||||
if (isa<Torch::TupleType>(desiredType) &&
|
||||
operand.getType().isa<Torch::BaseTensorType>()) {
|
||||
isa<Torch::BaseTensorType>(operand.getType())) {
|
||||
Type intType = Torch::IntType::get(b.getContext());
|
||||
Type sizeListType = Torch::ListType::get(intType);
|
||||
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
|
||||
|
|
|
@ -41,8 +41,8 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
|||
auto desiredListType = dyn_cast<Torch::ListType>(desiredType);
|
||||
if (!desiredListType)
|
||||
return operand;
|
||||
if (operand.getType().isa<Torch::BaseTensorType>() &&
|
||||
desiredListType.getContainedType().isa<Torch::IntType>()) {
|
||||
if (isa<Torch::BaseTensorType>(operand.getType()) &&
|
||||
isa<Torch::IntType>(desiredListType.getContainedType())) {
|
||||
return b.create<AtenSizeOp>(loc, desiredType, operand);
|
||||
}
|
||||
return operand;
|
||||
|
|
|
@ -259,7 +259,7 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
|||
Type originalResultType = result.getType();
|
||||
Type updatedType;
|
||||
if (auto originalBaseTensorType =
|
||||
originalResultType.template dyn_cast<BaseTensorType>()) {
|
||||
dyn_cast<BaseTensorType>(originalResultType)) {
|
||||
// If we didn't get any new information, there is nothing left for us to do.
|
||||
updatedType = meetTensorTypes(originalBaseTensorType,
|
||||
cast<BaseTensorType>(newResultType));
|
||||
|
@ -267,7 +267,7 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
|||
return rewriter.notifyMatchFailure(
|
||||
calculateOp, "New type information does not refine old type");
|
||||
} else if (auto originalResultType =
|
||||
result.getType().template dyn_cast<Torch::NumberType>()) {
|
||||
dyn_cast<Torch::NumberType>(result.getType())) {
|
||||
if (!isa<Torch::FloatType, Torch::IntType>(newResultType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
calculateOp,
|
||||
|
|
|
@ -35,7 +35,7 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
|||
|
||||
// Calculate the updated type incorporating the new information.
|
||||
Type impliedTypeFromDtype;
|
||||
if (result.getType().isa<Torch::NumberType>()) {
|
||||
if (isa<Torch::NumberType>(result.getType())) {
|
||||
FailureOr<Type> torchType =
|
||||
getTorchTypeForScalarType(op->getContext(), dtypeScalarType);
|
||||
if (failed(torchType)) {
|
||||
|
@ -45,7 +45,7 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
|||
}
|
||||
impliedTypeFromDtype = *torchType;
|
||||
} else if (auto originalResultType =
|
||||
result.getType().dyn_cast<BaseTensorType>()) {
|
||||
dyn_cast<BaseTensorType>(result.getType())) {
|
||||
FailureOr<Type> builtinType =
|
||||
getTypeForScalarType(op->getContext(), dtypeScalarType);
|
||||
if (failed(builtinType)) {
|
||||
|
@ -168,12 +168,12 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimNumToTensorScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto originalResultType = op.getResult().getType().cast<BaseTensorType>();
|
||||
auto originalResultType = cast<BaseTensorType>(op.getResult().getType());
|
||||
if (originalResultType.hasDtype())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "`PrimNumToTensorScalarOp` already has a dtype");
|
||||
|
||||
if (op.getA().getType().isa<Torch::NumberType>()) {
|
||||
if (isa<Torch::NumberType>(op.getA().getType())) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"`PrimNumToTensorScalarOp`'s input "
|
||||
"should have concrete Scalar Type.");
|
||||
|
|
|
@ -27,7 +27,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
MLIRContext *context = op.getContext();
|
||||
auto tensorType = self.getType().cast<BaseTensorType>();
|
||||
auto tensorType = cast<BaseTensorType>(self.getType());
|
||||
if (!tensorType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(op, "unranked tensor");
|
||||
int64_t rank = tensorType.getSizes().size();
|
||||
|
@ -96,7 +96,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
|||
sizes.push_back(kUnknownSize);
|
||||
}
|
||||
|
||||
auto originalResultType = result.getType().cast<BaseTensorType>();
|
||||
auto originalResultType = cast<BaseTensorType>(result.getType());
|
||||
auto impliedTypesFromShape =
|
||||
cast<BaseTensorType>(originalResultType)
|
||||
.getWithSizesAndDtype(ArrayRef(sizes),
|
||||
|
|
|
@ -44,9 +44,9 @@ bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
|
|||
}
|
||||
|
||||
torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||
if (type.isa<Float32Type>())
|
||||
if (isa<Float32Type>(type))
|
||||
return torch_upstream::ScalarType::Float;
|
||||
if (type.isa<Float64Type>())
|
||||
if (isa<Float64Type>(type))
|
||||
return torch_upstream::ScalarType::Double;
|
||||
if (type.isSignedInteger(64))
|
||||
return torch_upstream::ScalarType::Long;
|
||||
|
@ -64,11 +64,11 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
|||
return torch_upstream::ScalarType::Byte;
|
||||
if (type.isSignedInteger(8))
|
||||
return torch_upstream::ScalarType::Char;
|
||||
if (type.isa<QUInt8Type>())
|
||||
if (isa<QUInt8Type>(type))
|
||||
return torch_upstream::ScalarType::QUInt8;
|
||||
if (type.isa<QInt8Type>())
|
||||
if (isa<QInt8Type>(type))
|
||||
return torch_upstream::ScalarType::QInt8;
|
||||
if (type.isa<QInt32Type>())
|
||||
if (isa<QInt32Type>(type))
|
||||
return torch_upstream::ScalarType::QInt32;
|
||||
if (isa<ComplexType>(type)) {
|
||||
mlir::Type complexElemType = cast<ComplexType>(type).getElementType();
|
||||
|
@ -185,7 +185,7 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
|||
// Helper to convert a tensor to a specific scalar type.
|
||||
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
|
||||
Value input, Type dtype) {
|
||||
BaseTensorType origType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType origType = cast<BaseTensorType>(input.getType());
|
||||
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
|
||||
// `convertIntVal` contains the corresponding integer for the dtype which is
|
||||
// used by the aten.to.dtype op.
|
||||
|
@ -202,7 +202,7 @@ bool Torch::isBuiltInType(Type type) {
|
|||
}
|
||||
|
||||
std::optional<unsigned> Torch::getTensorRank(Value tensor) {
|
||||
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
|
||||
if (!tensorType.hasSizes())
|
||||
return std::nullopt;
|
||||
return tensorType.getSizes().size();
|
||||
|
@ -279,7 +279,7 @@ SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
|
|||
// Return the squeezed tensor or failure.
|
||||
FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||
Location loc, int64_t dim, Value input) {
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(loc, "input tensor must have size");
|
||||
}
|
||||
|
@ -314,7 +314,7 @@ FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
// Return the unsqueezed tensor or failure.
|
||||
FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Value input, Value dim) {
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op, "input tensor must have size");
|
||||
}
|
||||
|
@ -348,9 +348,9 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
|
|||
SmallVector<int64_t> &resultShape,
|
||||
SmallVector<Value> &resultShapeValue) {
|
||||
SmallVector<int64_t> shapeA{
|
||||
inputA.getType().cast<BaseTensorType>().getSizes()};
|
||||
cast<BaseTensorType>(inputA.getType()).getSizes()};
|
||||
SmallVector<int64_t> shapeB{
|
||||
inputB.getType().cast<BaseTensorType>().getSizes()};
|
||||
cast<BaseTensorType>(inputB.getType()).getSizes()};
|
||||
unsigned rankA = shapeA.size();
|
||||
unsigned rankB = shapeB.size();
|
||||
unsigned minRank = rankA > rankB ? rankB : rankA;
|
||||
|
@ -504,9 +504,8 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc,
|
|||
BaseTensorType inputType, Value scalar) {
|
||||
assert(inputType.hasDtype() && "input must have dtype");
|
||||
SmallVector<int64_t> sizes;
|
||||
BaseTensorType rank0TensorTy =
|
||||
inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())
|
||||
.cast<BaseTensorType>();
|
||||
BaseTensorType rank0TensorTy = cast<BaseTensorType>(
|
||||
inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()));
|
||||
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
||||
ValueRange{});
|
||||
|
@ -531,9 +530,9 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
|
|||
return rewriter.getF32Type();
|
||||
if (inputType.isBF16())
|
||||
return rewriter.getF32Type();
|
||||
if (inputType.isa<Float32Type>())
|
||||
if (isa<Float32Type>(inputType))
|
||||
return rewriter.getF32Type();
|
||||
if (inputType.isa<Float64Type>())
|
||||
if (isa<Float64Type>(inputType))
|
||||
return rewriter.getF64Type();
|
||||
if (inputType.isFloat8E5M2())
|
||||
return rewriter.getF32Type();
|
||||
|
|
|
@ -34,9 +34,9 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ToBuiltinTensorOp::verify() {
|
||||
auto resultType = getResult().getType().cast<TensorType>();
|
||||
auto resultType = cast<TensorType>(getResult().getType());
|
||||
auto operandType =
|
||||
getOperand().getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
|
||||
cast<Torch::ValueTensorType>(getOperand().getType()).toBuiltinTensor();
|
||||
if (!haveSameSizeAndElementType(resultType, operandType)) {
|
||||
return emitError()
|
||||
<< "operand and result must have the same size and dtype";
|
||||
|
@ -49,7 +49,7 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
|||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto resultType =
|
||||
operands[0].getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
|
||||
cast<Torch::ValueTensorType>(operands[0].getType()).toBuiltinTensor();
|
||||
if (!resultType)
|
||||
return failure();
|
||||
inferredReturnTypes.push_back(resultType);
|
||||
|
@ -62,8 +62,8 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
|||
|
||||
LogicalResult FromBuiltinTensorOp::verify() {
|
||||
auto resultType =
|
||||
getResult().getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
|
||||
auto operandType = getOperand().getType().cast<TensorType>();
|
||||
cast<Torch::ValueTensorType>(getResult().getType()).toBuiltinTensor();
|
||||
auto operandType = cast<TensorType>(getOperand().getType());
|
||||
if (!haveSameSizeAndElementType(resultType, operandType)) {
|
||||
return emitError()
|
||||
<< "operand and result must have the same size and dtype";
|
||||
|
|
|
@ -36,7 +36,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
|||
ValueRange inputs,
|
||||
Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
if (!inputs[0].getType().isa<Torch::BaseTensorType>())
|
||||
if (!isa<Torch::BaseTensorType>(inputs[0].getType()))
|
||||
return {};
|
||||
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
|
||||
});
|
||||
|
@ -44,7 +44,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
|||
Torch::ValueTensorType type,
|
||||
ValueRange inputs, Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<TensorType>());
|
||||
assert(isa<TensorType>(inputs[0].getType()));
|
||||
return builder.create<FromBuiltinTensorOp>(loc, type, inputs[0]);
|
||||
};
|
||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||
|
@ -64,13 +64,13 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
|
|||
if (!(type.getWidth() == 1 && type.isSignless()))
|
||||
return std::nullopt;
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<Torch::BoolType>());
|
||||
assert(isa<Torch::BoolType>(inputs[0].getType()));
|
||||
return builder.create<ToI1Op>(loc, inputs[0]).getResult();
|
||||
});
|
||||
auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type,
|
||||
ValueRange inputs, Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<IntegerType>());
|
||||
assert(isa<IntegerType>(inputs[0].getType()));
|
||||
return builder.create<FromI1Op>(loc, inputs[0]);
|
||||
};
|
||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||
|
@ -99,7 +99,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
|
|||
auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
|
||||
ValueRange inputs, Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<IntegerType>());
|
||||
assert(isa<IntegerType>(inputs[0].getType()));
|
||||
return builder.create<FromI64Op>(loc, inputs[0]);
|
||||
};
|
||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||
|
@ -116,13 +116,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
|
|||
[](OpBuilder &builder, Float64Type type, ValueRange inputs,
|
||||
Location loc) -> std::optional<Value> {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<Torch::FloatType>());
|
||||
assert(isa<Torch::FloatType>(inputs[0].getType()));
|
||||
return builder.create<ToF64Op>(loc, inputs[0]).getResult();
|
||||
});
|
||||
auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type,
|
||||
ValueRange inputs, Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<Float64Type>());
|
||||
assert(isa<Float64Type>(inputs[0].getType()));
|
||||
return builder.create<FromF64Op>(loc, inputs[0]);
|
||||
};
|
||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||
|
@ -153,7 +153,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
|
|||
auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type,
|
||||
ValueRange inputs, Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<IntegerType>());
|
||||
assert(isa<IntegerType>(inputs[0].getType()));
|
||||
return builder.create<I64ToGeneratorOp>(loc, inputs[0]);
|
||||
};
|
||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||
|
|
|
@ -42,7 +42,7 @@ public:
|
|||
|
||||
// get inputs: lhs, rhsQuant, scales, zps
|
||||
Value lhs = adaptor.getOperands()[0];
|
||||
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsType = cast<RankedTensorType>(lhs.getType());
|
||||
if (!lhsType) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ public:
|
|||
int lhsReductDimSize = lhsShape.back();
|
||||
|
||||
Value rhsQuant = adaptor.getOperands()[1];
|
||||
auto rhsType = rhsQuant.getType().cast<RankedTensorType>();
|
||||
auto rhsType = cast<RankedTensorType>(rhsQuant.getType());
|
||||
if (!rhsType) {
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ public:
|
|||
if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
|
||||
return failure();
|
||||
|
||||
auto rhsType = rhs.getType().dyn_cast<ValueTensorType>();
|
||||
auto rhsType = dyn_cast<ValueTensorType>(rhs.getType());
|
||||
if (!rhsType)
|
||||
return failure();
|
||||
|
||||
|
@ -88,7 +88,7 @@ public:
|
|||
ValueTensorType newRhsType = ValueTensorType::get(
|
||||
rewriter.getContext(), tensorShape, unpackedElementType);
|
||||
|
||||
auto elements = constOp.getValueAttr().dyn_cast<DenseIntElementsAttr>();
|
||||
auto elements = dyn_cast<DenseIntElementsAttr>(constOp.getValueAttr());
|
||||
if (!elements)
|
||||
return failure();
|
||||
|
||||
|
|
|
@ -234,7 +234,7 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
|
|||
if (!globalOp.getValue().has_value())
|
||||
return globalOp.emitError("global op must have a value");
|
||||
|
||||
RankedTensorType tensorType = globalOp.getType().cast<RankedTensorType>();
|
||||
RankedTensorType tensorType = cast<RankedTensorType>(globalOp.getType());
|
||||
MemRefType memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
|
||||
|
@ -252,7 +252,7 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
|
|||
static LogicalResult
|
||||
bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp,
|
||||
OpBuilder &b, SmallVector<Operation *> &toErase) {
|
||||
RankedTensorType tensorType = globalLoadOp.getType().cast<RankedTensorType>();
|
||||
RankedTensorType tensorType = cast<RankedTensorType>(globalLoadOp.getType());
|
||||
MemRefType memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
|
||||
|
@ -271,7 +271,7 @@ bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp,
|
|||
OpBuilder &b,
|
||||
SmallVector<Operation *> &toErase) {
|
||||
RankedTensorType tensorType =
|
||||
globalStoreOp.getValue().getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(globalStoreOp.getValue().getType());
|
||||
MemRefType memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
|
||||
|
@ -300,7 +300,7 @@ class MLProgramBufferize : public MLProgramBufferizeBase<MLProgramBufferize> {
|
|||
SmallVector<Operation *> toErase;
|
||||
|
||||
auto walkResult = module.walk([&](ml_program::GlobalOp op) {
|
||||
if (auto type = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto type = dyn_cast<RankedTensorType>(op.getType())) {
|
||||
if (!type.hasStaticShape()) {
|
||||
// If the ml_program.global has dynamically shaped tensor.
|
||||
op.emitError(
|
||||
|
@ -387,8 +387,8 @@ mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
|
|||
|
||||
Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
|
||||
Value to) {
|
||||
auto memrefTypeFrom = from.getType().cast<MemRefType>();
|
||||
auto memrefTypeTo = to.getType().cast<MemRefType>();
|
||||
auto memrefTypeFrom = cast<MemRefType>(from.getType());
|
||||
auto memrefTypeTo = cast<MemRefType>(to.getType());
|
||||
(void)memrefTypeFrom;
|
||||
assert(memrefTypeFrom && memrefTypeTo &&
|
||||
memrefTypeFrom.getRank() == memrefTypeTo.getRank());
|
||||
|
|
Loading…
Reference in New Issue