mirror of https://github.com/llvm/torch-mlir
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label use of the builtin tnesor type for modeling unknown dtypes. The `!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor. The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic tensor. The new types look as follows syntactically: ``` // Least-static-information, non-value-semantic tensor. !torch.tensor // Explicit form of least-static-information variant. !torch.tensor<*,unk> // Least-static-information, value-semantic tensor. !torch.vtensor // Explicit form of least-static-information variant. !torch.vtensor<*,unk> // Fixed-set of allowable element types, with first-class support for // Torch's frontend signedness semantics. !torch.tensor<*,si32> // First-class support for unknown dtypes. !torch.tensor<[?,?,?],unk> // Standard MLIR representation of `?` for unknown dimensions. !torch.tensor<[?,2,?,4],unk> // Statically shaped / dtyped example. !torch.vtensor<[1,2,3,4],f32> ``` This required fairly significant changes throughout the compiler, but overall it is a big cleanup. We now have a much clearer layering of "the Torch frontend lowering" vs "lowering to std + linalg + etc.". At the C++ level, there is `ValueTensorType`, `NonValueTensorType`. We also have a helper `BaseTensorType` (kind of like ShapedType) which interoperates with those two. Included changes: - New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for creating torch tensor literals in the frontend. - Consistently use signedness for the types (except i1 which I didn't touch -- we need to sort out the situation with !basicpy.BoolType there anyway so will be attending to that soon) - Frontend can annotate whether an argument to the function has value semantics. We currently require this, as our backend contract does not currently allow us to even model the non-value-semantic case. Before, the value-semantic assumption was randomly injected in the middle of the pass pipeline. - Move ArrayToTensor (now called MaximizeValueSemantics) and RefinePublicReturn passes to torch dialect. - The TorchToStd and TorchToLinalg passes are now type conversions from `!torch.vtensor` to `tensor` and use the dialect conversion infra. The overall conversion pipeline is set up following the best practices of the "Type Conversions the Not-So-Hard Way" talk. This required introducing `torch-func-builtin-tensorize` and `torch-finalizing-builtin-tensorize` passes analogous to the upstream bufferization passes with the corresponding names (mostly just copypasta from there). - Misc Torch-level canonicalizations -- we now cleanly layer the lowering to std later in the pipeline, so we are gradually lessening our reliance on random std constant folding before we get to that point. Recommended review order: - New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp - New ops in TorchOps.td / TorchOps.cpp - Less important / more mechanical stuff - Frontend changes. - Pass changes/additions in `Torch/Transforms` and `Conversion/`pull/218/head
parent
b7b7fd4959
commit
370e3270ab
|
@ -524,24 +524,15 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
|||
|
||||
MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
||||
auto loc = getCurrentLocation();
|
||||
MlirAttribute valueAttribute = converTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirValue constTensorValue =
|
||||
funcBuilder->getGeneralConstant(loc, valueAttribute);
|
||||
|
||||
// Create an array from the tensor constant via the
|
||||
// numpy.create_array_from_tensor op.
|
||||
MlirType constArrayType =
|
||||
npcompNdArrayTypeGetFromShaped(mlirAttributeGetType(valueAttribute));
|
||||
MlirOperationState state = mlirOperationStateGet(
|
||||
toMlirStringRef("numpy.create_array_from_tensor"), loc);
|
||||
mlirOperationStateAddOperands(&state, 1, &constTensorValue);
|
||||
mlirOperationStateAddResults(&state, 1, &constArrayType);
|
||||
MlirOperation constArrayOp = mlirOperationCreate(&state);
|
||||
|
||||
funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(constArrayOp);
|
||||
MlirValue constArrayValue = mlirOperationGetResult(constArrayOp, 0);
|
||||
funcBuilder->mapTensor(tensor, constArrayValue);
|
||||
return constArrayValue;
|
||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirOperation tensorOp = createMlirOperationAtEnd(
|
||||
funcBuilder->getEntryBlock(), "torch.tensor", loc,
|
||||
npcompNonValueTensorTypeGetFromShaped(
|
||||
mlirAttributeGetType(denseElements)),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);
|
||||
funcBuilder->mapTensor(tensor, tensorValue);
|
||||
return tensorValue;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, BackendSelect, m) {
|
||||
|
|
|
@ -180,16 +180,18 @@ static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
|||
auto tuple = py::cast<py::tuple>(pyArgAnnotations[i]);
|
||||
auto shape = tuple[0];
|
||||
auto dtype = tuple[1];
|
||||
auto hasValueSemantics = tuple[2];
|
||||
if (!shape.is_none()) {
|
||||
argAnnotations[i].shape = py::cast<std::vector<int64_t>>(shape);
|
||||
}
|
||||
if (!dtype.is_none()) {
|
||||
argAnnotations[i].dtype = convertToC10ScalarType(dtype);
|
||||
}
|
||||
argAnnotations[i].hasValueSemantics = py::cast<bool>(hasValueSemantics);
|
||||
};
|
||||
}
|
||||
|
||||
void ClassAnnotator::annotateShapesAndDtypes(c10::ClassType &rootClassType,
|
||||
void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
|
||||
std::vector<std::string> path,
|
||||
py::list argAnnotations) {
|
||||
if (path.size() == 0) {
|
||||
|
@ -279,6 +281,8 @@ std::string ArgAnnotation::toString(int argIndex) {
|
|||
} else {
|
||||
ss << "<none>\n";
|
||||
}
|
||||
ss << " hasValueSemantics = " << (hasValueSemantics ? "true" : "false")
|
||||
<< "\n";
|
||||
ss << "}\n";
|
||||
return ss.str();
|
||||
}
|
||||
|
@ -333,6 +337,6 @@ void torch_mlir::initClassAnnotatorBindings(py::module &m) {
|
|||
.def(py::init<>())
|
||||
.def("exportPath", &ClassAnnotator::exportPath)
|
||||
.def("exportNone", &ClassAnnotator::exportNone)
|
||||
.def("annotateShapesAndDtypes", &ClassAnnotator::annotateShapesAndDtypes)
|
||||
.def("annotateArgs", &ClassAnnotator::annotateArgs)
|
||||
.def("__repr__", &ClassAnnotator::toString);
|
||||
}
|
||||
|
|
|
@ -44,10 +44,26 @@ struct ArgAnnotation {
|
|||
// Each entry represents the size of each dimension of a tensor with known
|
||||
// rank. `-1` represents an unknown size along that dimension.
|
||||
c10::optional<std::vector<int64_t>> shape;
|
||||
|
||||
// If not None, represents information known about the dtype of the argument
|
||||
// (the argument must be a tensor).
|
||||
c10::optional<c10::ScalarType> dtype;
|
||||
|
||||
// If true, means that the user code will treat this argument as if it
|
||||
// has value semantics (the argument must be a tensor).
|
||||
//
|
||||
// In particular, this means that use code:
|
||||
// - expects the argument will not be mutated
|
||||
// - expects that any mutation to the argument internal to the program will
|
||||
// not be reflected externally.
|
||||
//
|
||||
// A value of `false` preserves the default Torch semantics and is a
|
||||
// safe default.
|
||||
//
|
||||
// TODO: Also add a "last use" / "dead" flag, which enables more powerful
|
||||
// optimizations like reusing the input buffer for scratch space.
|
||||
bool hasValueSemantics = false;
|
||||
|
||||
std::string toString(int argIndex);
|
||||
};
|
||||
|
||||
|
@ -135,14 +151,14 @@ public:
|
|||
// Annotate shapes and dtypes of the arguments of a method at path `path` from
|
||||
// `rootClassType`.
|
||||
//
|
||||
// `argAnnotations` should be a list of 2-tuples, with the first element
|
||||
// `argAnnotations` should be a list of 3-tuples, with the first element
|
||||
// being a list/tuple of integer sizes, and the second being a torch datatype
|
||||
// object, such as `torch.float32`, `torch.int8`, etc.
|
||||
// object, such as `torch.float32`, `torch.int8`, etc., and the last being
|
||||
// a "has value semantics" boolean.
|
||||
// These will be put into an `ArgAnnotation` struct -- see there for
|
||||
// precise definitions of the promised semantics of each entry.
|
||||
void annotateShapesAndDtypes(c10::ClassType &rootClassType,
|
||||
std::vector<std::string> path,
|
||||
py::list argAnnotations);
|
||||
void annotateArgs(c10::ClassType &rootClassType,
|
||||
std::vector<std::string> path, py::list argAnnotations);
|
||||
|
||||
// The annotations collected so far.
|
||||
const ClassAnnotationMap &getAnnotationMap();
|
||||
|
|
|
@ -339,11 +339,13 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
|
||||
// Import the bulk tensor representation.
|
||||
at::Tensor tensor = ivalue.toTensor().contiguous();
|
||||
MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirOperation constant = createMlirOperationAtEnd(
|
||||
importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
|
||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirOperation tensorOp =
|
||||
createMlirOperationAtEnd(importBlock, "torch.tensor", loc,
|
||||
npcompNonValueTensorTypeGetFromShaped(
|
||||
mlirAttributeGetType(denseElements)),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
MlirValue tensorReprValue = mlirOperationGetResult(constant, 0);
|
||||
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||
|
||||
// Construct the complete tensor value. This is trivial for most tensors, but
|
||||
// for quantized tensors (and probably sparse too, TBD) there is more for us
|
||||
|
@ -355,9 +357,9 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
// compiler stages that are building a statically modeled quantization
|
||||
// representation will need to convert this to their representation.
|
||||
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
||||
MlirType quantizedTensorType = mlirRankedTensorTypeGetChecked(
|
||||
loc, shape.size(), shape.data(),
|
||||
typeMapper.mapFromTorchScalarType(tensor.scalar_type()), {nullptr});
|
||||
MlirType quantizedTensorType = npcompNonValueTensorTypeGet(
|
||||
context, shape.size(), shape.data(),
|
||||
typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
|
||||
if (tensor.qscheme() == c10::kPerTensorAffine) {
|
||||
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
|
||||
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
|
||||
|
@ -375,12 +377,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
tensorValue = tensorReprValue;
|
||||
}
|
||||
|
||||
// Convert the tensor to ndarray to match Torch's default-mutable semantics.
|
||||
MlirOperation ndarray = createMlirOperationAtEnd(
|
||||
importBlock, "numpy.create_array_from_tensor", loc,
|
||||
npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
|
||||
tensorValue);
|
||||
return mlirOperationGetResult(ndarray, 0);
|
||||
return tensorValue;
|
||||
}
|
||||
|
||||
void IValueImporter::importMethod(torch::jit::Function *function,
|
||||
|
@ -486,17 +483,31 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
|||
if (!annotation || !annotation->argAnnotations.has_value()) {
|
||||
return {nullptr};
|
||||
}
|
||||
auto &shape = annotation->argAnnotations.value()[argIndex].shape;
|
||||
auto &dtype = annotation->argAnnotations.value()[argIndex].dtype;
|
||||
c10::optional<std::vector<int64_t>> &maybeShape =
|
||||
annotation->argAnnotations.value()[argIndex].shape;
|
||||
c10::optional<c10::ScalarType> &maybeDtype =
|
||||
annotation->argAnnotations.value()[argIndex].dtype;
|
||||
bool hasValueSemantics =
|
||||
annotation->argAnnotations.value()[argIndex].hasValueSemantics;
|
||||
|
||||
// TODO: Handle unranked tensors and tensors with unknown dtype (but
|
||||
// possibly known ranks/sizes).
|
||||
if (!shape || !dtype) {
|
||||
if (!maybeShape || !maybeDtype) {
|
||||
return {nullptr};
|
||||
}
|
||||
auto typeBound = npcompNdArrayTypeGetRanked(
|
||||
shape->size(), shape->data(),
|
||||
TypeMapper(context).mapFromTorchScalarType(
|
||||
mlirLocationUnknownGet(context), *dtype));
|
||||
|
||||
std::vector<int64_t> shape = *maybeShape;
|
||||
MlirType dtype = TypeMapper(context).mapFromTorchScalarType(
|
||||
mlirLocationUnknownGet(context), *maybeDtype);
|
||||
MlirType typeBound;
|
||||
if (hasValueSemantics) {
|
||||
typeBound = npcompValueTensorTypeGet(context, shape.size(),
|
||||
shape.data(), dtype);
|
||||
} else {
|
||||
typeBound = npcompNonValueTensorTypeGet(context, shape.size(),
|
||||
shape.data(), dtype);
|
||||
}
|
||||
|
||||
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
|
||||
"torch.type_bound", mlirTypeAttrGet(typeBound));
|
||||
return mlirDictionaryAttrGet(context, 1, &typeBoundAttr);
|
||||
|
|
|
@ -44,19 +44,15 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
|
|||
using c10::ScalarType;
|
||||
switch (scalarType) {
|
||||
case ScalarType::Byte:
|
||||
// TODO: convert to mlirIntegerTypeUnsignedGet once supported.
|
||||
return mlirIntegerTypeGet(context, 8);
|
||||
return mlirIntegerTypeUnsignedGet(context, 8);
|
||||
case ScalarType::Char:
|
||||
return mlirIntegerTypeGet(context, 8);
|
||||
return mlirIntegerTypeSignedGet(context, 8);
|
||||
case ScalarType::Short:
|
||||
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
||||
return mlirIntegerTypeGet(context, 16);
|
||||
return mlirIntegerTypeSignedGet(context, 16);
|
||||
case ScalarType::Int:
|
||||
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
||||
return mlirIntegerTypeGet(context, 32);
|
||||
return mlirIntegerTypeSignedGet(context, 32);
|
||||
case ScalarType::Long:
|
||||
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
||||
return mlirIntegerTypeGet(context, 64);
|
||||
return mlirIntegerTypeSignedGet(context, 64);
|
||||
case ScalarType::Bool:
|
||||
return npcompBoolTypeGet(context);
|
||||
case ScalarType::Double:
|
||||
|
@ -128,19 +124,21 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
case TypeKind::TensorType: {
|
||||
auto tensorType = torchType->cast<c10::TensorType>();
|
||||
// Element type.
|
||||
MlirType elementType;
|
||||
MlirType elementType = {nullptr};
|
||||
if (tensorType->scalarType()) {
|
||||
elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
|
||||
if (mlirTypeIsNull(elementType))
|
||||
return {nullptr};
|
||||
} else {
|
||||
elementType = npcompAnyDtypeTypeGet(context);
|
||||
}
|
||||
// Sizes.
|
||||
auto &sizes = tensorType->symbolic_sizes();
|
||||
if (!sizes.rank()) {
|
||||
// Unranked.
|
||||
return npcompNdArrayTypeGetUnranked(elementType);
|
||||
return npcompNonValueTensorTypeGet(context,
|
||||
/*numSizes=*/0,
|
||||
/*optionalSizes=*/nullptr,
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
}
|
||||
// Ranked with possibly dynamic dims.
|
||||
auto &symbolicShape = tensorType->symbolic_sizes();
|
||||
|
@ -150,7 +148,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
auto shapeSymbol = symbolicShape[i];
|
||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
||||
}
|
||||
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
|
||||
return npcompNonValueTensorTypeGet(context, dims.size(),
|
||||
/*optionalSizes=*/dims.data(),
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
}
|
||||
case TypeKind::ClassType: {
|
||||
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
|
||||
|
@ -214,7 +215,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
|||
// just erase them and let the compiler decide.
|
||||
|
||||
auto sizes = tensor.sizes();
|
||||
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
|
||||
return npcompNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
|
||||
elementType);
|
||||
}
|
||||
|
||||
MlirType
|
||||
|
@ -243,7 +245,7 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context,
|
|||
outputTypes.size(), outputTypes.data());
|
||||
}
|
||||
|
||||
MlirAttribute torch_mlir::converTensorToMlirElementsAttr(at::Tensor tensor,
|
||||
MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
|
||||
MlirLocation loc) {
|
||||
MlirContext context = mlirLocationGetContext(loc);
|
||||
TypeMapper typeMapper(context);
|
||||
|
@ -273,12 +275,12 @@ MlirAttribute torch_mlir::converTensorToMlirElementsAttr(at::Tensor tensor,
|
|||
elementType = mlirIntegerTypeGet(context, 1);
|
||||
} else if (tensor.scalar_type() == ScalarType::QInt8) {
|
||||
// This function returns the underlying integer representation of the tensor
|
||||
// as an elements attr. That underlying representation is of type i8
|
||||
// as an elements attr. That underlying representation is of type si8
|
||||
// for a torch.qint8 tensor.
|
||||
// Caller code is responsible for materializing the proper op that
|
||||
// incorporates the quantization scheme to create a tensor of `!torch.qint8`
|
||||
// element type.
|
||||
elementType = mlirIntegerTypeGet(context, 8);
|
||||
elementType = mlirIntegerTypeSignedGet(context, 8);
|
||||
} else {
|
||||
elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
|
||||
}
|
||||
|
@ -343,7 +345,7 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
|
|||
case torch::jit::AttributeKind::s:
|
||||
return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol)));
|
||||
case torch::jit::AttributeKind::t:
|
||||
return converTensorToMlirElementsAttr(node->t(symbol), loc);
|
||||
return convertTensorToMlirElementsAttr(node->t(symbol), loc);
|
||||
default: {
|
||||
std::stringstream msg;
|
||||
msg << "unhandled: value attribute kind " << toString(kind);
|
||||
|
|
|
@ -59,7 +59,7 @@ MlirType getFunctionTypeFromSchema(MlirContext context,
|
|||
const c10::FunctionSchema &schema);
|
||||
|
||||
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
|
||||
MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor,
|
||||
MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor,
|
||||
MlirLocation loc);
|
||||
|
||||
MlirAttribute importAttribute(MlirLocation loc, torch::jit::Node *node,
|
||||
|
|
|
@ -16,8 +16,8 @@ class MmModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
@ -39,7 +39,7 @@ class TanhModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 3, -1], torch.float32),
|
||||
([2, 3, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.tanh(x)
|
||||
|
@ -56,8 +56,8 @@ class MmTanhModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.tanh(self.matmul(lhs, rhs))
|
||||
|
|
|
@ -23,7 +23,7 @@ class Mlp1LayerModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.tanh0(self.fc0(x))
|
||||
|
@ -45,7 +45,7 @@ class Mlp2LayerModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
x = self.tanh0(self.fc0(x))
|
||||
|
|
|
@ -28,7 +28,7 @@ class QuantizedMLP(nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 16], torch.float32),
|
||||
([1, 16], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
x = self.quantize(x)
|
||||
|
|
|
@ -21,7 +21,7 @@ class ResNet18Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, 3, -1, -1], torch.float32),
|
||||
([-1, 3, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, img):
|
||||
return self.resnet.forward(img)
|
||||
|
|
|
@ -39,7 +39,7 @@ torch.jit.save(recursivescriptmodule, '/tmp/foo.pt')
|
|||
|
||||
class_annotator.exportNone(recursivescriptmodule._c._type())
|
||||
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
|
||||
class_annotator.annotateShapesAndDtypes(recursivescriptmodule._c._type(), ['forward'], [
|
||||
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
||||
None,
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32),
|
||||
|
|
|
@ -32,7 +32,7 @@ torch.jit.save(recursivescriptmodule, '/tmp/foo.pt')
|
|||
|
||||
class_annotator.exportNone(recursivescriptmodule._c._type())
|
||||
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
|
||||
class_annotator.annotateShapesAndDtypes(recursivescriptmodule._c._type(), ['forward'], [
|
||||
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
||||
None,
|
||||
([2, 3, -1], torch.float32)
|
||||
])
|
||||
|
|
|
@ -32,7 +32,7 @@ torch.jit.save(recursivescriptmodule, '/tmp/foo.pt')
|
|||
|
||||
class_annotator.exportNone(recursivescriptmodule._c._type())
|
||||
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
|
||||
class_annotator.annotateShapesAndDtypes(recursivescriptmodule._c._type(), ['forward'], [
|
||||
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
||||
None,
|
||||
([2, 3, -1], torch.float32)
|
||||
])
|
||||
|
|
|
@ -77,7 +77,7 @@ def _recursively_extract_annotations(
|
|||
if hasattr(method, '_npcomp_export'):
|
||||
class_annotator.exportPath(scripted._c._type(), [method_name])
|
||||
if hasattr(method, '_npcomp_arg_annotations'):
|
||||
class_annotator.annotateShapesAndDtypes(
|
||||
class_annotator.annotateArgs(
|
||||
scripted._c._type(), [method_name],
|
||||
method._npcomp_arg_annotations)
|
||||
# Recurse.
|
||||
|
|
|
@ -440,7 +440,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
|
||||
# Misc tensor ops.
|
||||
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::dim : (Tensor) -> (int)")
|
||||
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
|
||||
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
||||
|
||||
# Primitive ops
|
||||
|
@ -452,7 +452,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::mul.float : (float, float) -> (float)")
|
||||
emit("aten::lt.float_int : (float, int) -> (bool)")
|
||||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||
emit("aten::len.t : (t[]) -> (int)", has_canonicalizer=True)
|
||||
emit("aten::len.t : (t[]) -> (int)",
|
||||
has_folder=True,
|
||||
has_canonicalizer=True)
|
||||
|
||||
|
||||
def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
|
||||
|
|
|
@ -13,7 +13,6 @@ with mb.capture_function("arange_test", []) as f:
|
|||
x = torch.arange(10)
|
||||
f.returns([x])
|
||||
|
||||
# CHECK: %[[CST:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi64>
|
||||
# CHECK: %[[R:.*]] = numpy.create_array_from_tensor %[[CST]]
|
||||
# CHECK: return %[[R]]
|
||||
# CHECK: %[[T:.*]] = torch.tensor(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi64>) : !torch.tensor<[10],si64>
|
||||
# CHECK: return %[[T]]
|
||||
mb.module.operation.print()
|
||||
|
|
|
@ -19,17 +19,15 @@ with mb.capture_function("add3", [t0, t1, t2]) as f:
|
|||
f.returns([t3])
|
||||
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
# CHECK-LABEL: func @add3(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[1,2,3,4]:f32>, %[[VAL_1:.*]]: !numpy.ndarray<[1,2,3,4]:f32>,
|
||||
# CHECK-SAME: %[[VAL_2:.*]]: !numpy.ndarray<[1,2,3,4]:f32>) -> !numpy.ndarray<[1,2,3,4]:f32> {
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !torch.tensor<[1,2,3,4],f32>, %[[VAL_1:.*]]: !torch.tensor<[1,2,3,4],f32>,
|
||||
# CHECK-SAME: %[[VAL_2:.*]]: !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32> {
|
||||
# CHECK: %[[VAL_3:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_4:.*]] = constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
|
||||
# CHECK: %[[VAL_5:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_6:.*]] = constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
|
||||
# CHECK: %[[VAL_7:.*]] = numpy.create_array_from_tensor %[[VAL_4]] : (tensor<1x2x3x4xf32>) -> !numpy.ndarray<[1,2,3,4]:f32>
|
||||
# CHECK: %[[VAL_8:.*]] = torch.operator "aten.add.out"(%[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_7]]) : (!numpy.ndarray<[1,2,3,4]:f32>, !numpy.ndarray<[1,2,3,4]:f32>, i64, !numpy.ndarray<[1,2,3,4]:f32>) -> !numpy.ndarray<[1,2,3,4]:f32>
|
||||
# CHECK: %[[VAL_9:.*]] = numpy.create_array_from_tensor %[[VAL_6]] : (tensor<1x2x3x4xf32>) -> !numpy.ndarray<[1,2,3,4]:f32>
|
||||
# CHECK: %[[VAL_10:.*]] = torch.operator "aten.add.out"(%[[VAL_8]], %[[VAL_2]], %[[VAL_5]], %[[VAL_9]]) : (!numpy.ndarray<[1,2,3,4]:f32>, !numpy.ndarray<[1,2,3,4]:f32>, i64, !numpy.ndarray<[1,2,3,4]:f32>) -> !numpy.ndarray<[1,2,3,4]:f32>
|
||||
# CHECK: return %[[VAL_10]] : !numpy.ndarray<[1,2,3,4]:f32>
|
||||
# CHECK: %[[VAL_4:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_5:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: %[[VAL_6:.*]] = torch.operator "aten.add.out"(%[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_5]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, i64, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: %[[VAL_7:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: %[[VAL_8:.*]] = torch.operator "aten.add.out"(%[[VAL_6]], %[[VAL_2]], %[[VAL_4]], %[[VAL_7]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, i64, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: return %[[VAL_8]] : !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: }
|
||||
|
||||
print(mb.module)
|
||||
|
|
|
@ -21,5 +21,5 @@ with mb.capture_function("bn2d", [ones]) as f:
|
|||
# behavior.
|
||||
# CHECK-LABEL: @bn2d
|
||||
# CHECK: %[[RESULT:.*]]:3 = torch.operator "aten.native_batch_norm"(%arg0
|
||||
# CHECK: return %[[RESULT]]#0 : !numpy.ndarray<[42,123,4,5]:f32>
|
||||
# CHECK: return %[[RESULT]]#0 : !torch.tensor<[42,123,4,5],f32>
|
||||
print(mb.module)
|
||||
|
|
|
@ -32,26 +32,25 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
|
|||
|
||||
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
# CHECK-LABEL: func @conv2d_fwd(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
|
||||
# CHECK: %[[VAL_1:.*]] = constant opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>
|
||||
# CHECK: %[[VAL_2:.*]] = constant opaque<"_", "0xDEADBEEF"> : tensor<4xf32>
|
||||
# CHECK: %[[VAL_3:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_4:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_5:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_6:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_7:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_8:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_9:.*]] = basicpy.bool_constant false
|
||||
# CHECK: %[[VAL_10:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_11:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_12:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_13:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
|
||||
# CHECK: %[[VAL_14:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
|
||||
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_16:.*]] = basicpy.build_list %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_17:.*]] = basicpy.build_list %[[VAL_7]], %[[VAL_8]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_18:.*]] = basicpy.build_list %[[VAL_10]], %[[VAL_11]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_19:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[VAL_17]], %[[VAL_9]], %[[VAL_18]], %[[VAL_12]]) : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
|
||||
# CHECK: return %[[VAL_19]] : !numpy.ndarray<[3,4,8,8]:f32>
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !torch.tensor<[3,16,10,10],f32>) -> !torch.tensor<[3,4,8,8],f32> {
|
||||
# CHECK: %[[VAL_1:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_2:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_3:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_4:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_5:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_6:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_7:.*]] = basicpy.bool_constant false
|
||||
# CHECK: %[[VAL_8:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_9:.*]] = constant 0 : i64
|
||||
# CHECK: %[[VAL_10:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_11:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
|
||||
# CHECK: %[[VAL_12:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
|
||||
# CHECK: %[[VAL_13:.*]] = basicpy.build_list %[[VAL_1]], %[[VAL_2]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_14:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_16:.*]] = basicpy.build_list %[[VAL_8]], %[[VAL_9]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType, !basicpy.ListType, i64) -> !torch.tensor<[3,4,8,8],f32>
|
||||
# CHECK: return %[[VAL_17]] : !torch.tensor<[3,4,8,8],f32>
|
||||
# CHECK: }
|
||||
|
||||
mb.module.operation.print(large_elements_limit=2)
|
||||
|
|
|
@ -26,19 +26,19 @@ recursivescriptmodule = torch.jit.script(test_module)
|
|||
annotator = torch_mlir.ClassAnnotator()
|
||||
class_type = recursivescriptmodule._c._type()
|
||||
try:
|
||||
annotator.annotateShapesAndDtypes(class_type, [], [])
|
||||
annotator.annotateArgs(class_type, [], [])
|
||||
except Exception as e:
|
||||
# CHECK: Empty annotated path. Can only annotate shapes/dtypes of a method of a class.
|
||||
print(e)
|
||||
|
||||
try:
|
||||
annotator.annotateShapesAndDtypes(class_type, ['forward'], [None])
|
||||
annotator.annotateArgs(class_type, ['forward'], [None])
|
||||
except Exception as e:
|
||||
# CHECK: Arg annotations should have one entry per function parameter (including self).
|
||||
print(e)
|
||||
|
||||
try:
|
||||
annotator.annotateShapesAndDtypes(class_type, ['forward'], [None, ([3, 4], 42)])
|
||||
annotator.annotateArgs(class_type, ['forward'], [None, ([3, 4], 42, False)])
|
||||
except Exception as e:
|
||||
# This is just the raw repr of the object in quotes.
|
||||
# CHECK: unsupported scalar type '42'
|
|
@ -24,11 +24,11 @@ annotator = torch_mlir.ClassAnnotator()
|
|||
class_type = recursivescriptmodule._c._type()
|
||||
# CHECK: func private @__torch__.TestModule.forward(
|
||||
# CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">,
|
||||
# CHECK-SAME: %arg1: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[?,1024]:i8>}
|
||||
# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>}
|
||||
# CHECK-SAME: ) -> !basicpy.NoneType
|
||||
annotator.annotateShapesAndDtypes(class_type, ['forward'], [
|
||||
annotator.annotateArgs(class_type, ['forward'], [
|
||||
None,
|
||||
((-1, 1024), torch.int8),
|
||||
((-1, 1024), torch.int8, True),
|
||||
])
|
||||
|
||||
# # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
|
@ -30,7 +30,7 @@ class TestModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
self.s = Submodule()
|
||||
|
||||
def forward(self, tensor):
|
||||
def forward(self, tensor, value_tensor):
|
||||
return self.s.forward()
|
||||
|
||||
|
||||
|
@ -43,9 +43,10 @@ class_type = recursivescriptmodule._c._type()
|
|||
annotator.exportNone(class_type)
|
||||
annotator.exportPath(class_type, ['s', 'exported'])
|
||||
annotator.exportPath(class_type, ['s', 'forward'])
|
||||
annotator.annotateShapesAndDtypes(class_type, ['forward'], [
|
||||
annotator.annotateArgs(class_type, ['forward'], [
|
||||
None,
|
||||
((1024, 2), torch.float32),
|
||||
((1024, 2), torch.float32, False),
|
||||
((42, -1, 7), torch.int8, True),
|
||||
])
|
||||
|
||||
# "Change detector" test + "documentation" for the repr of `ClassAnnotator`.
|
||||
|
@ -91,10 +92,17 @@ annotator.annotateShapesAndDtypes(class_type, ['forward'], [
|
|||
# CHECK-NEXT: ArgAnnotation(0) {
|
||||
# CHECK-NEXT: dtype = <none>
|
||||
# CHECK-NEXT: shape = <none>
|
||||
# CHECK-NEXT: hasValueSemantics = false
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-NEXT: ArgAnnotation(1) {
|
||||
# CHECK-NEXT: dtype = Float
|
||||
# CHECK-NEXT: shape = [1024, 2]
|
||||
# CHECK-NEXT: hasValueSemantics = false
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-NEXT: ArgAnnotation(2) {
|
||||
# CHECK-NEXT: dtype = Char
|
||||
# CHECK-NEXT: shape = [42, -1, 7]
|
||||
# CHECK-NEXT: hasValueSemantics = true
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-NEXT: }
|
||||
|
|
|
@ -17,8 +17,8 @@ class MmModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.float32),
|
||||
([4, 5], torch.float32),
|
||||
([3, 4], torch.float32, False),
|
||||
([4, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
@ -40,10 +40,12 @@ print(annotator)
|
|||
# CHECK: ArgAnnotation(1) {
|
||||
# CHECK: dtype = Float
|
||||
# CHECK: shape = [3, 4]
|
||||
# CHECK: hasValueSemantics = false
|
||||
# CHECK: }
|
||||
# CHECK: ArgAnnotation(2) {
|
||||
# CHECK: dtype = Float
|
||||
# CHECK: shape = [4, 5]
|
||||
# CHECK: hasValueSemantics = true
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
|
|
|
@ -14,19 +14,19 @@ mb = torch_mlir.ModuleBuilder()
|
|||
# Interesting test case, where a function calls a method.
|
||||
|
||||
# CHECK-LABEL: func private @__torch__.TestModule.forward
|
||||
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
|
||||
# CHECK: %[[F:.*]] = constant @__torch__.calls_method : (!torch.nn.Module<"__torch__.TestModule">, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType
|
||||
# CHECK: %[[RET:.*]] = call_indirect %[[F]](%[[ARG0]], %[[ARG1]]) : (!torch.nn.Module<"__torch__.TestModule">, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType
|
||||
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !basicpy.NoneType {
|
||||
# CHECK: %[[F:.*]] = constant @__torch__.calls_method : (!torch.nn.Module<"__torch__.TestModule">, !torch.tensor) -> !basicpy.NoneType
|
||||
# CHECK: %[[RET:.*]] = call_indirect %[[F]](%[[ARG0]], %[[ARG1]]) : (!torch.nn.Module<"__torch__.TestModule">, !torch.tensor) -> !basicpy.NoneType
|
||||
# CHECK: return %[[RET]] : !basicpy.NoneType
|
||||
# CHECK: }
|
||||
# CHECK-LABEL: func private @__torch__.TestModule.method
|
||||
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
|
||||
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !basicpy.NoneType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: return %[[RET]] : !basicpy.NoneType
|
||||
# CHECK: }
|
||||
# CHECK-LABEL: func private @__torch__.calls_method
|
||||
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[ARG0]]["method"] (%[[ARG1]]) : !torch.nn.Module<"__torch__.TestModule">, (!numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType
|
||||
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !basicpy.NoneType {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[ARG0]]["method"] (%[[ARG1]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.tensor) -> !basicpy.NoneType
|
||||
# CHECK: return %[[RET]] : !basicpy.NoneType
|
||||
# CHECK: }
|
||||
|
||||
|
|
|
@ -12,14 +12,14 @@ import torch_mlir
|
|||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func private @__torch__.TestModule.forward
|
||||
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
# CHECK: %[[VAL_2:.*]] = constant @__torch__.identity : (!numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: %[[VAL_3:.*]] = call_indirect %[[VAL_2]](%[[ARG1]]) : (!numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: return %[[VAL_3]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
# CHECK: %[[VAL_2:.*]] = constant @__torch__.identity : (!torch.tensor) -> !torch.tensor
|
||||
# CHECK: %[[VAL_3:.*]] = call_indirect %[[VAL_2]](%[[ARG1]]) : (!torch.tensor) -> !torch.tensor
|
||||
# CHECK: return %[[VAL_3]] : !torch.tensor
|
||||
# CHECK: }
|
||||
# CHECK-LABEL: func private @__torch__.identity
|
||||
# CHECK-SAME: (%[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
# CHECK: return %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK-SAME: (%[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
# CHECK: return %[[ARG]] : !torch.tensor
|
||||
# CHECK: }
|
||||
|
||||
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
|
||||
|
|
|
@ -22,8 +22,8 @@ mb = torch_mlir.ModuleBuilder()
|
|||
# we don't need to capture their names when FileCheck testing).
|
||||
|
||||
# CHECK-LABEL: func private @__torch__.TestModule.forward
|
||||
# CHECK-SAME: (%[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
# CHECK: return %[[X]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK-SAME: (%[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[X:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
# CHECK: return %[[X]] : !torch.tensor
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
|
||||
|
|
|
@ -14,10 +14,10 @@ mb = torch_mlir.ModuleBuilder()
|
|||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# CHECK: %[[A:.*]] = numpy.create_array_from_tensor
|
||||
# CHECK: %[[T:.*]] = torch.tensor
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.slot "t1", %[[A]]
|
||||
# CHECK: torch.slot "t2", %[[A]]
|
||||
# CHECK: torch.slot "t1", %[[T]]
|
||||
# CHECK: torch.slot "t2", %[[T]]
|
||||
self.t1 = self.t2 = torch.tensor([10., 20.])
|
||||
|
||||
|
||||
|
|
|
@ -27,8 +27,8 @@ class TestModule(torch.nn.Module):
|
|||
self.callee(self.t1, self.t2)
|
||||
# CHECK-LABEL: func private @__torch__.TestModule.callee(
|
||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">,
|
||||
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK-SAME: %[[X:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[Y:.*]]: !torch.tensor
|
||||
def callee(self, x, y):
|
||||
pass
|
||||
|
||||
|
|
|
@ -19,19 +19,17 @@ class TestModule(torch.nn.Module):
|
|||
2,
|
||||
bias_=False,
|
||||
dtype=torch.qint8)
|
||||
# CHECK-DAG: %[[SCALE:.*]] = basicpy.numeric_constant {{.*}} : f64
|
||||
# CHECK-DAG: %[[ZERO_POINT:.*]] = basicpy.numeric_constant 0 : i64
|
||||
# CHECK-DAG: %[[INT_REPR:.*]] = constant dense<{{.*}}> : tensor<2x5xi8>
|
||||
# CHECK-DAG: %[[WEIGHTS:.*]] = torch.per_tensor_affine.create %[[INT_REPR]], %[[SCALE]], %[[ZERO_POINT]] : tensor<2x5xi8>, f64, i64 -> tensor<2x5x!torch.qint8>
|
||||
# CHECK-DAG: %[[WEIGHTS_ARRAY:.*]] = numpy.create_array_from_tensor %[[WEIGHTS]] : (tensor<2x5x!torch.qint8>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK-DAG: %[[BIAS:.*]] = constant dense<{{.*}}> : tensor<2xf32>
|
||||
# CHECK-DAG: %[[BIAS_ARRAY:.*]] = numpy.create_array_from_tensor %[[BIAS]] : (tensor<2xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK-DAG: %[[LINEAR_PARAMS:.*]] = torch.linear_params.create %[[WEIGHTS_ARRAY]], %[[BIAS_ARRAY]] : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: %[[SCALE:.*]] = basicpy.numeric_constant {{.*}} : f64
|
||||
# CHECK: %[[ZERO_POINT:.*]] = basicpy.numeric_constant 0 : i64
|
||||
# CHECK: %[[INT_REPR:.*]] = torch.tensor({{.*}}) : !torch.tensor<[2,5],si8>
|
||||
# CHECK: %[[WEIGHTS:.*]] = torch.per_tensor_affine.create %[[INT_REPR]], %num, %num0_i64 : !torch.tensor<[2,5],si8>, f64, i64 -> !torch.tensor<[2,5],!torch.qint8>
|
||||
# CHECK: %[[BIAS:.*]] = torch.tensor({{.*}}) : !torch.tensor<[2],f32>
|
||||
# CHECK: %[[LINEAR_PARAMS:.*]] = torch.linear_params.create %[[WEIGHTS]], %[[BIAS]] : !torch.tensor<[2,5],!torch.qint8>, !torch.tensor<[2],f32>
|
||||
@torch.jit.export
|
||||
def test_linear(self, t):
|
||||
return self.linear(t)
|
||||
|
||||
# CHECK: %[[LINEAR_PARAMS_NO_BIAS:.*]] = torch.linear_params.create %{{.*}} : !numpy.ndarray<*:!numpy.any_dtype>{{$}}
|
||||
# CHECK: %[[LINEAR_PARAMS_NO_BIAS:.*]] = torch.linear_params.create %{{[^,]*}} : !torch.tensor<[2,6],!torch.qint8>
|
||||
@torch.jit.export
|
||||
def test_linear_no_bias(self, t):
|
||||
return self.linear_no_bias(t)
|
||||
|
|
|
@ -15,16 +15,14 @@ class TestModule(torch.nn.Module):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
# TODO: Test (and make work) tensors that alias each other.
|
||||
self.t = torch.ones(1)
|
||||
self.p = torch.nn.Parameter(torch.arange(3.0))
|
||||
self.ones = torch.ones(1)
|
||||
self.arange = torch.nn.Parameter(torch.arange(3.0))
|
||||
|
||||
# CHECK: %[[CP:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>
|
||||
# CHECK: %[[P:.*]] = numpy.create_array_from_tensor %[[CP]] : (tensor<3xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: %[[CT:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
# CHECK: %[[T:.*]] = numpy.create_array_from_tensor %[[CT]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: %[[ARANGE:.*]] = torch.tensor(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32>
|
||||
# CHECK: %[[ONES:.*]] = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor<[1],f32>
|
||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||
# CHECK: torch.slot "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: torch.slot "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: torch.slot "arange", %[[ARANGE]] : !torch.tensor<[3],f32>
|
||||
# CHECK: torch.slot "ones", %[[ONES]] : !torch.tensor<[1],f32>
|
||||
# CHECK: }
|
||||
|
||||
|
||||
|
|
|
@ -10,9 +10,9 @@ import torch_mlir
|
|||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @__torch__.f(
|
||||
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType
|
||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !basicpy.ListType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !basicpy.ListType
|
||||
# CHECK: return %[[RET]] : !basicpy.ListType
|
||||
|
||||
@mb.import_function
|
||||
|
|
|
@ -15,9 +15,9 @@ mb = torch_mlir.ModuleBuilder()
|
|||
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
|
||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.tensor {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : i64 -> !torch.tensor
|
||||
# CHECK: return %[[RET]] : !torch.tensor
|
||||
# CHECK: }
|
||||
|
||||
@mb.import_function
|
||||
|
@ -26,9 +26,9 @@ def prim_NumToTensor(i: int):
|
|||
return _to_tensor(i)
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_Print(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !basicpy.NoneType {
|
||||
# CHECK: %[[STR:.*]] = basicpy.bytes_constant "x"
|
||||
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !torch.tensor
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_Print(x):
|
||||
|
@ -94,8 +94,8 @@ def prim_ListUnpack(l: typing.List[int]):
|
|||
return val
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_dtype(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> i64 {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> i64
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> i64 {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !torch.tensor -> i64
|
||||
# CHECK: return %[[RET]] : i64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
|
@ -103,8 +103,8 @@ def prim_dtype(x):
|
|||
return x.dtype
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_layout(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> i64 {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.layout %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> i64
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> i64 {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.layout %[[ARG]] : !torch.tensor -> i64
|
||||
# CHECK: return %[[RET]] : i64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
|
@ -112,8 +112,8 @@ def prim_layout(x):
|
|||
return x.layout
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_device(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !torch.Device {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> !torch.Device
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.Device {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !torch.tensor -> !torch.Device
|
||||
# CHECK: return %[[RET]] : !torch.Device
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
|
|
|
@ -10,9 +10,9 @@ import torch_mlir
|
|||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @__torch__.f(
|
||||
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType
|
||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !basicpy.TupleType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !basicpy.TupleType
|
||||
# CHECK: return %[[RET]] : !basicpy.TupleType
|
||||
|
||||
@mb.import_function
|
||||
|
|
|
@ -167,6 +167,58 @@ int npcompTypeIsAQInt8(MlirType t);
|
|||
/** Gets the !torch.qint8 type. */
|
||||
MlirType npcompQInt8TypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.tensor type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.tensor type */
|
||||
int npcompTypeIsANonValueTensor(MlirType t);
|
||||
|
||||
/** Gets a !torch.tensor type.
|
||||
*
|
||||
* - `optionalSizes` is allowed to be null, meaning that no size information is
|
||||
* present (and `numSizes` is ignored in that case).
|
||||
* - `optionalDtype` is allowed to be null, meaning that no dtype information is
|
||||
* present.
|
||||
*
|
||||
*/
|
||||
MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype);
|
||||
|
||||
/** Gets the !torch.tensor type with the least static information. */
|
||||
MlirType
|
||||
npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||
|
||||
/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */
|
||||
MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.vtensor type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.vtensor type */
|
||||
int npcompTypeIsAValueTensor(MlirType t);
|
||||
|
||||
/** Gets a !torch.vtensor type.
|
||||
*
|
||||
* - `optionalSizes` is allowed to be null, meaning that no size information is
|
||||
* present (and `numSizes` is ignored in that case).
|
||||
* - `optionalDtype` is allowed to be null, meaning that no dtype information is
|
||||
* present.
|
||||
*
|
||||
*/
|
||||
MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype);
|
||||
|
||||
/** Gets the !torch.tensor type with the least static information. */
|
||||
MlirType
|
||||
npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||
|
||||
/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */
|
||||
MlirType npcompValueTensorTypeGetFromShaped(MlirType type);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -18,8 +18,6 @@ namespace NPCOMP {
|
|||
namespace Numpy {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createPublicFunctionsToTensorPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createArrayToTensorPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
||||
|
||||
} // namespace Numpy
|
||||
|
||||
|
|
|
@ -20,44 +20,4 @@ def NumpyPublicFunctionsToTensor : Pass<"numpy-public-functions-to-tensor", "Mod
|
|||
let constructor = "mlir::NPCOMP::Numpy::createPublicFunctionsToTensorPass()";
|
||||
}
|
||||
|
||||
def NumpyArrayToTensor : Pass<"numpy-array-to-tensor", "FuncOp"> {
|
||||
let summary = "Replace arrays with tensors where possible (optimization only).";
|
||||
let description = [{
|
||||
This pass is analogous to an SSA-formation pass in a
|
||||
traditional compiler, with the added complication that arrays can alias
|
||||
each other in interesting ways.
|
||||
|
||||
The current code doesn't implement any fancy algorithm, and is intended
|
||||
to be just sufficient for a first e2e spike. An algorithm inspired by the
|
||||
SSA formation literature will need to be implemented.
|
||||
|
||||
Also, this pass doesn't currently handle interprocedural rewriting
|
||||
(of private functions), which is even more complex.
|
||||
}];
|
||||
let constructor = "mlir::NPCOMP::Numpy::createArrayToTensorPass()";
|
||||
}
|
||||
|
||||
|
||||
def NumpyRefinePublicReturn : Pass<"numpy-refine-public-return", "ModuleOp"> {
|
||||
let summary = "Refine public return";
|
||||
let constructor = "mlir::NPCOMP::Numpy::createRefinePublicReturnPass()";
|
||||
let description = [{
|
||||
Refines types of values return from public functions based on
|
||||
intraprocedural information.
|
||||
|
||||
This pass effectively encodes an assumption by the pass pipeline author that
|
||||
the public calling convention of the module can have its types refined,
|
||||
without causing ABI mismatches. This is frequently true -- for example, in
|
||||
many systems, `tensor<?x?xf32>`, `tensor<3x3xf32>` and
|
||||
`tensor<*x!numpy.any_dtype>` are all the same data structure on calling
|
||||
convention boundaries.
|
||||
|
||||
This pass is expected to run after shape refinement has occurred to
|
||||
otherwise resolve shapes, and is currently mainly useful to convert
|
||||
rank/dtype-erased function boundaries to ranked, dtyped code for
|
||||
compiler backends.
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif // NPCOMP_NUMPY_PASSES
|
||||
|
|
|
@ -237,6 +237,7 @@ def Torch_AtenDimOp : Torch_Op<"aten.dim", [
|
|||
AnyTorchIntType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenSizeOp : Torch_Op<"aten.size", [
|
||||
|
@ -387,6 +388,7 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
|
|||
AnyTorchIntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.td"
|
||||
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
|
||||
|
||||
def Torch_Dialect : Dialect {
|
||||
let name = "torch";
|
||||
|
@ -19,24 +18,9 @@ def Torch_Dialect : Dialect {
|
|||
let description = [{
|
||||
Top-level dialect for interfacing PyTorch and MLIR.
|
||||
|
||||
This dialect contains types and structural ops that model enough of
|
||||
PyTorch's behavior to allow for easy import/call-out. While not aiming to
|
||||
be completely isomorphic, it is laid out to make conversion in/out
|
||||
systematic for the supported features (some of which are aspirational):
|
||||
- Transitions between mutable and immutable tensors.
|
||||
- Gradient associations and management.
|
||||
- Custom ops.
|
||||
- Types specific to PyTorch such as torch.nn.Module structures
|
||||
- Module level constructs like quantization parameters, globals, etc.
|
||||
This dialect maintains a fairly isomorphic representation with TorchScript.
|
||||
|
||||
Where possible, this dialect composes with types and ops from the `Numpy`
|
||||
and `Basicpy` dialects, and those dialects should be considered "upstream"
|
||||
for basic Python and ND-Array based programming constructs.
|
||||
|
||||
As a key point, this dialect does not contain any custom operations,
|
||||
including those that people would typically associate as core (see
|
||||
the `ATen` dialect for mathematical ops like add, conv, etc), instead
|
||||
modeling the open op-system that PyTorch reasons about natively.
|
||||
TODO: Add more detail here.
|
||||
}];
|
||||
|
||||
let hasRegionArgAttrVerify = 1;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTraits.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
@ -23,6 +24,21 @@
|
|||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Torch {
|
||||
/// Create code to copy `tensor` to type `newType`.
|
||||
///
|
||||
/// This involves two independent steps, which we keep orthogonal in our
|
||||
/// IR representation.
|
||||
/// 1. Adding/removing static information about sizes/dtype.
|
||||
/// 2. Performing the copy, which allows us to add/remove value semantics.
|
||||
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
|
||||
Value tensor);
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
|
||||
using SlotOp = ::mlir::NPCOMP::Torch::SlotOp;
|
||||
static SlotOp getEmptyKey() {
|
||||
|
|
|
@ -14,6 +14,7 @@ include "npcomp/Interfaces/Traits.td"
|
|||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
|
@ -46,7 +47,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
|||
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||
torch.slot "i", %num3_i64 : i64
|
||||
torch.slot "f", %num : f64
|
||||
torch.slot "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "t", %t : !torch.tensor
|
||||
torch.slot "submodule", %1 : !torch.nn.Module
|
||||
} : !torch.nn.Module<"my_class_name">
|
||||
```
|
||||
|
@ -126,7 +127,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
|||
torch.attr "b" : !basicpy.BoolType
|
||||
torch.attr "i" : i64
|
||||
torch.attr "f" : f64
|
||||
torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "t" : !torch.tensor
|
||||
torch.attr "submodule" : !torch.nn.Module<"empty">
|
||||
torch.method "method", @f
|
||||
}
|
||||
|
@ -135,7 +136,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
|||
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||
torch.slot "i", %num3_i64 : i64
|
||||
torch.slot "f", %num : f64
|
||||
torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "t", %t : !torch.tensor
|
||||
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
|
||||
} : !torch.nn.Module<"test">
|
||||
```
|
||||
|
@ -428,6 +429,8 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
|
|||
|
||||
This op bridges that impedance mismatch. This op allows casting a value
|
||||
from one type to a type that it is a subtype of to model this behavior.
|
||||
This op uses the TorchScript notion of subtype, which matches the
|
||||
Python notion of subtype presented in PEP 483.
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTorchType:$operand);
|
||||
|
@ -505,4 +508,165 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
|
|||
}];
|
||||
}
|
||||
|
||||
// TODO: Disaggregate this op into a value-semantic constant + val->nonval
|
||||
// conversion if needed.
|
||||
// Currently, this op can effectively hide val->nonval conversion, which makes
|
||||
// it an edge case for passes that care about that such as
|
||||
// torch-maximize-value-semantics.
|
||||
// So the suggestion would be to lower this to a `torch.vtensor` op
|
||||
// (+`torch.copy.tensor` if needed).
|
||||
// In particular, currently we end up relying on convert-torch-to-std
|
||||
// to effectively expose this (as part of lowering to `std.constant`) +
|
||||
// hoping that some canonicalization cleans it up.
|
||||
// The `torch-maximize-value-semantics` pass should be doing this
|
||||
// before we convert to std at all.
|
||||
def Torch_TensorOp : Torch_Op<"tensor", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Create a value of !torch.tensor type from a literal";
|
||||
let description = [{
|
||||
Example:
|
||||
```
|
||||
%0 = torch.tensor(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
|
||||
%1 = torch.tensor(dense<0.0> : tensor<3xf32>) : !torch.vtensor<[3],f32>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins ElementsAttr:$value);
|
||||
let results = (outs AnyTorchTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $value `)` attr-dict `:` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// InferTypeOpInterface:
|
||||
static bool isCompatibleReturnTypes(TypeRange inferred, TypeRange actual);
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
AllowsTypeRefinement,
|
||||
NoSideEffect]> {
|
||||
let summary = "Adds/removes static information from a tensor type.";
|
||||
let description = [{
|
||||
This op does not imply any runtime code. Semantically it is an identity
|
||||
function. However, it statically annotates (or erases) shape and dtype
|
||||
information from a tensor type.
|
||||
|
||||
This op *cannot* be used to add/remove value semantics from a tensor.
|
||||
For converting between the value-semantic and non-value-semantic domains,
|
||||
use `torch.copy.tensor`. The two ops are kept separate to prevent
|
||||
canonicalizations from accidentally dropping static information. In
|
||||
most cases, after running the `torch-refine-types` pass, this op becomes
|
||||
a no-op (the pass will incorporate the static information into other ops
|
||||
that allow type refinement).
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$operand
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `to` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_CopyTensorOp : Torch_Op<"copy.tensor", []> {
|
||||
let summary = "Makes a copy of a tensor.";
|
||||
let description = [{
|
||||
Changes to the original tensor will not be reflected in the copy.
|
||||
|
||||
This op can be used to interconvert between value-semantic and
|
||||
non-value-semantic tensors. However, this op *does not* allow
|
||||
adding/removing static information about sizes/dtype. For that, use
|
||||
`torch.tensor_static_info_cast`.
|
||||
|
||||
This op does not have the AllowsTypeRefinement trait because the operand
|
||||
and result types are coupled. Only places that know how to simultaneously
|
||||
update both types should be changing the type of this op.
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$operand
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
let verifier = "return ::verify(*this);";
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Ovewrite the contents of tensor with values from another.";
|
||||
let description = [{
|
||||
Replaces the contents of `overwritten` with corresponding values from
|
||||
`value`.
|
||||
|
||||
Immediately after this op has completed, indexing `overwritten` will result
|
||||
in identical values as indexing into `tensor`. Of course, later ops
|
||||
might mutate `overwritten`, so this relationship need not hold for the
|
||||
entire program.
|
||||
|
||||
This op has undefined behavior if the two tensors have different
|
||||
shapes or dtypes.
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$value,
|
||||
AnyTorchTensorType:$overwritten
|
||||
);
|
||||
let results = (outs
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$value `overwrites` $overwritten attr-dict
|
||||
`:` type($value) `,` type($overwritten)
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_ToBuiltinTensorOp : Torch_Op<"to_builtin_tensor", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||
]> {
|
||||
let summary = "Convert a `!torch.vtensor` to a `tensor`";
|
||||
let description = [{
|
||||
This op only operates on ValueTensorType, to avoid conflating conversions
|
||||
between value-semantic and non-value-semantic types.
|
||||
}];
|
||||
let arguments = (ins
|
||||
Torch_ValueTensorType:$operand
|
||||
);
|
||||
let results = (outs
|
||||
AnyTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_FromBuiltinTensorOp : Torch_Op<"from_builtin_tensor", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||
]> {
|
||||
let summary = "Convert a `tensor` to a `!torch.vtensor`";
|
||||
let description = [{
|
||||
This op only operates on ValueTensorType, to avoid conflating conversions
|
||||
between value-semantic and non-value-semantic types.
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyTensor:$operand
|
||||
);
|
||||
let results = (outs
|
||||
Torch_ValueTensorType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
|
@ -11,7 +11,118 @@
|
|||
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Torch {
|
||||
|
||||
class NonValueTensorType;
|
||||
class ValueTensorType;
|
||||
|
||||
/// Common getter function signature that covers all tensor types.
|
||||
/// Used for sharing code between NonValueTensorType and ValueTensorType.
|
||||
using GetTensorTypeFn =
|
||||
llvm::function_ref<Type(MLIRContext *, Optional<ArrayRef<int64_t>>, Type)>;
|
||||
|
||||
/// The representation of an unknown dimension size in an ArrayRef<int64_t>.
|
||||
constexpr static int64_t kUnknownSize = -1;
|
||||
|
||||
class BaseTensorType : public Type {
|
||||
public:
|
||||
using Type::Type;
|
||||
|
||||
/// Get the raw optional list of sizes.
|
||||
///
|
||||
/// It is expected that for many users, `hasSizes`/`getSizes` will be a more
|
||||
/// convenient API.
|
||||
Optional<ArrayRef<int64_t>> getOptionalSizes() const;
|
||||
|
||||
/// Get the raw nullable Type representing the dtype of this tensor type.
|
||||
///
|
||||
/// It is expected that for many users, `hasDtype`/`getDtype` will be a more
|
||||
/// convenient API.
|
||||
Type getOptionalDtype() const;
|
||||
|
||||
/// Return true if this type has a list of sizes.
|
||||
bool hasSizes() const { return getOptionalSizes().hasValue(); }
|
||||
|
||||
/// Get the list of sizes. Requires `hasSizes()`.
|
||||
ArrayRef<int64_t> getSizes() const {
|
||||
assert(hasSizes() && "must have sizes");
|
||||
return getOptionalSizes().getValue();
|
||||
}
|
||||
|
||||
/// Return true if all sizes of this tensor are known.
|
||||
bool areAllSizesKnown() const {
|
||||
return hasSizes() && llvm::all_of(getSizes(), [](int64_t size) {
|
||||
return size != kUnknownSize;
|
||||
});
|
||||
}
|
||||
|
||||
/// Return true if this type has a known dtype.
|
||||
bool hasDtype() const { return static_cast<bool>(getOptionalDtype()); }
|
||||
|
||||
/// Get the dtype. Requires `hasDtype()`.
|
||||
Type getDtype() const {
|
||||
assert(hasDtype() && "must have a dtype");
|
||||
return getOptionalDtype();
|
||||
}
|
||||
|
||||
/// Enable isa/dyn_cast for BaseTensorType.
|
||||
static bool classof(Type type);
|
||||
|
||||
/// Return true if this type has the same sizes and dtype as the other.
|
||||
bool hasSameSizesAndDtype(BaseTensorType other) const;
|
||||
|
||||
/// Return a type of the same kind as this one, but with sizes and dtype
|
||||
/// from `other`.
|
||||
Type getWithSizesAndDtypeFrom(BaseTensorType other) const;
|
||||
|
||||
/// Return a type of the same kind as this one, but with given raw optional
|
||||
/// sizes and raw optional dtype.
|
||||
Type getWithSizesAndDtype(Optional<ArrayRef<int64_t>> optionalSizes,
|
||||
Type optionalDtype) const;
|
||||
|
||||
/// Return a type with the same shape and dtype as this one, but with
|
||||
/// value semantics.
|
||||
ValueTensorType getWithValueSemantics() const;
|
||||
};
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Inline definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Torch {
|
||||
|
||||
inline Optional<ArrayRef<int64_t>> BaseTensorType::getOptionalSizes() const {
|
||||
if (auto tensor = dyn_cast<NonValueTensorType>())
|
||||
return tensor.getOptionalSizes();
|
||||
if (auto tensor = dyn_cast<ValueTensorType>())
|
||||
return tensor.getOptionalSizes();
|
||||
llvm_unreachable("not a BaseTensorType!");
|
||||
}
|
||||
|
||||
inline Type BaseTensorType::getOptionalDtype() const {
|
||||
if (auto tensor = dyn_cast<NonValueTensorType>())
|
||||
return tensor.getOptionalDtype();
|
||||
if (auto tensor = dyn_cast<ValueTensorType>())
|
||||
return tensor.getOptionalDtype();
|
||||
llvm_unreachable("not a BaseTensorType!");
|
||||
}
|
||||
|
||||
inline bool BaseTensorType::classof(Type type) {
|
||||
return type.isa<NonValueTensorType, ValueTensorType>();
|
||||
}
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||
|
|
|
@ -15,7 +15,9 @@ include "npcomp/Dialect/Torch/IR/TorchBase.td"
|
|||
// Type defs
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Torch_Type<string name, string typeMnemonic> : TypeDef<Torch_Dialect, name> {
|
||||
class Torch_Type<string name, string typeMnemonic,
|
||||
string baseCppClass = "::mlir::Type">
|
||||
: TypeDef<Torch_Dialect, name, [], baseCppClass> {
|
||||
let mnemonic = typeMnemonic;
|
||||
}
|
||||
|
||||
|
@ -44,6 +46,148 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
|||
}];
|
||||
}
|
||||
|
||||
// For standard ArrayRefs, which require allocation.
|
||||
class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
|
||||
AttrOrTypeParameter<
|
||||
"::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
|
||||
let allocator = [{
|
||||
if ($_self.hasValue()) {
|
||||
$_dst.getValue() = $_allocator.copyInto($_self.getValue());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
class AnyTorchTensorType<string name, string typeMnemonic>
|
||||
: Torch_Type<name, typeMnemonic, "::mlir::NPCOMP::Torch::BaseTensorType"> {
|
||||
let summary = "Multi-dimensional array modeling Torch's Tensor type";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
||||
```
|
||||
tensor-type ::= (`!torch.tensor` | `!torch.vtensor`) tensor-modifiers?
|
||||
tensor-modifiers ::= `<` sizes-spec `,` dtype-spec `>`
|
||||
sizes-spec ::= `*` | `[` size-list `]`
|
||||
size-list ::= /*empty*/ | size-list-nonempty
|
||||
size-list-nonempty = size (`,` size)*
|
||||
size ::= `?` | decimal-literal
|
||||
dtype-spec ::= `unk` | type
|
||||
```
|
||||
|
||||
Represents a multi-dimensional array to model Torch's `torch.Tensor` type.
|
||||
|
||||
If the type is `!torch.tensor`, it represents a general unrestricted
|
||||
`torch.Tensor`, including potential mutability, aliasing, etc.
|
||||
If the type is `!torch.vtensor` then the tensor is restricted to operations
|
||||
that have value semantics ("v" = "value semantics"). This helps to maintain
|
||||
a strict separation between the value-semantic and potentially-mutating
|
||||
worlds, as one of our main jobs in the compiler is to isolate the mutating
|
||||
parts as much as possible because most lower levels of the compiler stack
|
||||
are expected to require value semantics. E.g. npcomp's backend contract
|
||||
is mostly in terms of linalg-on-tensor for compute-heavy ops, which require
|
||||
a conversion to the builtin `tensor` type which has value semantics.
|
||||
Some notes about value semantics:
|
||||
- Using the type system described in PEP 483 (which TorchScript and other
|
||||
Python systems follow), `!torch.tensor` is a subtype of
|
||||
`!torch.vtensor`. Specifically, both types have the same set of values,
|
||||
but `!torch.tensor` additionally allows aliasing or mutating
|
||||
operations.
|
||||
- Despite being a subtype, a `!torch.tensor` carries *less* static
|
||||
information than a corresponding `!torch.vtensor`. In particular,
|
||||
`!torch.vtensor` carries the static information "not used in aliasing
|
||||
or mutating operations".
|
||||
- `!torch.vtensor` can be trivially converted to the builtin `tensor`
|
||||
type when the dtype is known (the builtin `tensor` type does not allow
|
||||
an unknown dtype).
|
||||
|
||||
In the absence of the `tensor-modifiers`, the type contains the minimal
|
||||
amount of static information. That is, `!torch.tensor` is equivalent to
|
||||
`!torch.tensor<*,unk>` (and similarly for `!torch.vtensor`).
|
||||
|
||||
If `sizes-spec` is not `*`, it indicates additional static information
|
||||
about the sizes of the tensor. It will consist of a list of elements,
|
||||
with length equal to the "rank" (in MLIR parlance) or "ndim"
|
||||
(in Torch parlance). Each element represents a size, with the typical
|
||||
MLIR representation of a number for a statically known size and `?` for a
|
||||
size that is unknown. Thus, the lattice consists of `*` as the least static
|
||||
information, followed by lists containing unknown sizes such as `[?,?,?]`
|
||||
which contribute rank information, followed by statically specified sizes
|
||||
for some dimensions such as `[?,3,?]`, followed by fully statically
|
||||
specified sizes such as `[2,3,4]`.
|
||||
|
||||
If `dtype-spec` is not `unk` ("unknown"), it contains an MLIR type
|
||||
which contributes static information about the dtype of the tensor.
|
||||
Only types allowed by Torch are permitted.
|
||||
```
|
||||
|-------------------|--------------------|
|
||||
| Torch Type | MLIR Type |
|
||||
|-------------------|--------------------|
|
||||
| torch.float16 | f16 |
|
||||
| torch.bfloat16 | bf16 |
|
||||
| torch.float32 | f32 |
|
||||
| torch.float64 | f64 |
|
||||
| torch.uint8 | ui8 |
|
||||
| torch.int8 | si8 |
|
||||
| torch.int16 | si16 |
|
||||
| torch.int32 | si32 |
|
||||
| torch.int64 | si64 |
|
||||
| torch.bool | i1 |
|
||||
| torch.qint8 | !torch.qint8 |
|
||||
|-------------------|--------------------|
|
||||
```
|
||||
|
||||
TODO: Support the full set of Torch dtypes.
|
||||
TODO: Use si1?
|
||||
|
||||
Note: We avoid the C++ identifier `TensorType` to avoid C++ name ambiguities
|
||||
with `mlir::TensorType`, since most code is transitively nested in
|
||||
both `::mlir` and `::mlir::NPCOMP::Torch` namespaces.
|
||||
|
||||
Note: We use the Torch-aligned terminology "sizes" and "dtype" instead of
|
||||
the MLIR-aligned terminology "rank/shape" and "element type". The cheat
|
||||
sheet is:
|
||||
- `hasRank()` -> `hasSizes()`
|
||||
- `getShape()` -> `getSizes()`
|
||||
- `getElementType()` -> `getDtype()` (but be sure that `hasDtype()` though).
|
||||
}];
|
||||
let parameters = (ins
|
||||
OptionalArrayRefParameter<"int64_t", "sizes of dimensions">:$optionalSizes,
|
||||
"::mlir::Type":$optionalDtype
|
||||
);
|
||||
let genVerifyDecl = 1;
|
||||
string extraBaseClassDeclaration = [{
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_NonValueTensorType : AnyTorchTensorType<"NonValueTensor", "tensor"> {
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
// Get this type, with value semantics added.
|
||||
ValueTensorType getWithValueSemantics() const;
|
||||
// Get the !torch.tensor type with the least static information.
|
||||
static NonValueTensorType getWithLeastStaticInformation(MLIRContext *context);
|
||||
// Get a NonValueTensorType with shape/dtype matching `type`.
|
||||
static NonValueTensorType getFromShaped(ShapedType type);
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
// Get this type, with value semantics removed.
|
||||
NonValueTensorType getWithoutValueSemantics() const;
|
||||
// Get the !torch.tensor type with the least static information.
|
||||
static ValueTensorType getWithLeastStaticInformation(MLIRContext *context);
|
||||
// Get a NonValueTensorType with shape/dtype matching `type`.
|
||||
static ValueTensorType getFromShaped(ShapedType type);
|
||||
// Get the builtin tensor type with the same static information as this one,
|
||||
// or nullptr if that is not possible (i.e. when the dtype is unknown).
|
||||
TensorType toBuiltinTensor() const;
|
||||
}];
|
||||
}
|
||||
|
||||
def AnyTorchTensorType : Type<
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Torch::BaseTensorType>()">,
|
||||
"Any Torch tensor type"
|
||||
>;
|
||||
|
||||
// TODO: It feels like this should be something more general.
|
||||
// However, to do that, we need to agree on construction operations
|
||||
// and the valid MLIR representations of the "None" state.
|
||||
|
@ -118,51 +262,6 @@ def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
|
|||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Torch has a fairly advanced and featureful Tensor type, and some of the
|
||||
// semantics are important to preserve in a compilation context. In the future,
|
||||
// a dedicated TorchTensor type may be introduced, but also, subsets of cases
|
||||
// and interop are well served by existing tensor-like types, which are
|
||||
// specifically permitted. Typically, on import, constraints are fairly loose
|
||||
// and based on how the program is captured. Settling on and refining to
|
||||
// specific types is done as part of lowering.
|
||||
//
|
||||
// While lowering it is useful to be able to distinguish between mutable and
|
||||
// immutable tensors:
|
||||
// - Mutable tensors can alias.
|
||||
// - Mutable tensors can be a view over another mutable tensor.
|
||||
// - Mutable tensors act as if reference counted and exist for the lifetime
|
||||
// of any reference or derived view.
|
||||
// Conversely, immutable tensors:
|
||||
// - Are normal SSA values representing the contents of the tensor.
|
||||
// - Cannot alias.
|
||||
// - Cannot be a view of any mutable value.
|
||||
// - Have undefined lifetimes.
|
||||
//
|
||||
// At the Torch dialect level, most things are modeled as an AnyTorchTensor;
|
||||
// however, when lowering to specific ops, further constraints are introduced,
|
||||
// necessitating copies, loads, and stores to be inserted to bridge worlds.
|
||||
def AnyTorchImmutableTensor : AnyTypeOf<[
|
||||
// Normal MLIR immutable tensors.
|
||||
AnyTensor,
|
||||
], "allowable torch immutable tensor">;
|
||||
|
||||
def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
|
||||
AnyTorchImmutableTensor,
|
||||
Basicpy_NoneType,
|
||||
], "allowable torch immutable tensor (or None)">;
|
||||
|
||||
def AnyTorchMutableTensor : AnyTypeOf<[
|
||||
// "Numpy-style" mutable NDArray. While not offering the full generality
|
||||
// of a Torch tensor, it models the same access patterns and implies the
|
||||
// same aliasing as Torch tensors.
|
||||
Numpy_NdArrayType,
|
||||
], "allowable torch mutable tensor">;
|
||||
|
||||
def AnyTorchTensorType : AnyTypeOf<[
|
||||
AnyTorchImmutableTensor,
|
||||
AnyTorchMutableTensor,
|
||||
], "Any tensor type legal to pass to a Torch kernel">;
|
||||
|
||||
def AnyTorchOptionalTensor : AnyTypeOf<[
|
||||
AnyTorchTensorType,
|
||||
Torch_OptionalType,
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHUTILS_H
|
||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHUTILS_H
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Torch {
|
||||
void setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
||||
TypeConverter &typeConverter);
|
||||
}
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHUTILS_H
|
|
@ -51,6 +51,13 @@ std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
|
|||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createReduceOpVariantsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createFuncBuiltinTensorizePass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createFinalizingBuiltinTensorizePass();
|
||||
|
||||
} // namespace Torch
|
||||
|
||||
|
|
|
@ -164,4 +164,74 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> {
|
||||
let summary = "Use value-semantic tensors where possible.";
|
||||
let description = [{
|
||||
Use value-semantic tensors where possible to make the program more
|
||||
analyzable by later passes (also, backends prefer value semantics as well).
|
||||
|
||||
This pass is analogous to an SSA-formation pass in a
|
||||
traditional compiler, with the added complication that arrays can alias
|
||||
each other in interesting ways.
|
||||
|
||||
The current code doesn't implement any fancy algorithm, and is intended
|
||||
to be just sufficient for a first e2e spike. An algorithm inspired by the
|
||||
SSA formation literature will need to be implemented.
|
||||
|
||||
Also, this pass doesn't currently handle interprocedural rewriting
|
||||
(of private functions), which is even more complex.
|
||||
}];
|
||||
let constructor = "mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass()";
|
||||
}
|
||||
|
||||
|
||||
def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
|
||||
let summary = "Refine public return";
|
||||
let constructor = "mlir::NPCOMP::Torch::createRefinePublicReturnPass()";
|
||||
let description = [{
|
||||
Refines types of values returned from public functions based on
|
||||
intraprocedural information.
|
||||
|
||||
This pass effectively encodes an assumption by the pass pipeline author that
|
||||
the public calling convention of the module can have its types refined,
|
||||
without causing ABI mismatches. This is frequently true -- for example, in
|
||||
many systems, `!torch.vtensor<[?,?],f32>`, `!torch.vtensor<[3,3],f32>` and
|
||||
`!torch.vtensor` are all the same data structure on calling
|
||||
convention boundaries.
|
||||
|
||||
This pass is expected to run after shape refinement has occurred to
|
||||
otherwise resolve shapes, and is currently mainly useful to convert
|
||||
rank/dtype-erased function boundaries to ranked, dtyped code for
|
||||
compiler backends.
|
||||
|
||||
This pass also changes the return to be a value tensor. This is incorrect
|
||||
in general because users may rely on the aliasing properties of non-value
|
||||
tensors, but for now it is deemed expedient to include this in this pass.
|
||||
TODO: Avoid hardcoding the value tensor assumption. In general, much
|
||||
as the type bound of an argument can be marked as having value semantics
|
||||
at the frontend level based on user concerns, so too should the returns
|
||||
from the function be annotated as having value semantics.
|
||||
}];
|
||||
}
|
||||
|
||||
def FuncBuiltinTensorize : Pass<"torch-func-builtin-tensorize", "ModuleOp"> {
|
||||
let summary = "Convert functions to operate on builtin tensors";
|
||||
let constructor = "mlir::NPCOMP::Torch::createFuncBuiltinTensorizePass()";
|
||||
let description = [{
|
||||
Partial type conversion pass analogous in scope to the upstream
|
||||
`func-bufferize` pass. See details there.
|
||||
}];
|
||||
}
|
||||
|
||||
def FinalizingBuiltinTensorize
|
||||
: Pass<"torch-finalizing-builtin-tensorize", "FuncOp"> {
|
||||
let summary = "Finalizes a partial conversion to builtin tensors";
|
||||
let constructor =
|
||||
"mlir::NPCOMP::Torch::createFinalizingBuiltinTensorizePass()";
|
||||
let description = [{
|
||||
Analogous in scope to the upstream `finalizing-bufferize` pass.
|
||||
See details there.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_TORCH_PASSES
|
||||
|
|
|
@ -216,3 +216,61 @@ int npcompTypeIsAQInt8(MlirType t) {
|
|||
MlirType npcompQInt8TypeGet(MlirContext context) {
|
||||
return wrap(Torch::QInt8Type::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.tensor type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsANonValueTensor(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NonValueTensorType>();
|
||||
}
|
||||
|
||||
MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||
if (optionalSizes)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
return wrap(Torch::NonValueTensorType::get(
|
||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
}
|
||||
|
||||
MlirType
|
||||
npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
|
||||
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
|
||||
unwrap(context)));
|
||||
}
|
||||
|
||||
MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type) {
|
||||
return wrap(Torch::NonValueTensorType::getFromShaped(
|
||||
unwrap(type).cast<ShapedType>()));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.vtensor type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsAValueTensor(MlirType t) {
|
||||
return unwrap(t).isa<Torch::ValueTensorType>();
|
||||
}
|
||||
|
||||
MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||
if (optionalSizes)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
return wrap(Torch::ValueTensorType::get(
|
||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
}
|
||||
|
||||
MlirType
|
||||
npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
|
||||
return wrap(
|
||||
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
|
||||
}
|
||||
|
||||
MlirType npcompValueTensorTypeGetFromShaped(MlirType type) {
|
||||
return wrap(
|
||||
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
#include "mlir/Dialect/MemRef/IR/MemRef.h" // TODO: For `memref.dim`.
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
@ -40,15 +40,21 @@ using namespace mlir::NPCOMP::Torch;
|
|||
// that these patterns become mostly mechanical associations of
|
||||
// "aten.foo -> linalg.foo".
|
||||
|
||||
static LogicalResult verifyLinalgCompatibleTypes(Operation *op, PatternRewriter &rewriter) {
|
||||
static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
|
||||
PatternRewriter &rewriter) {
|
||||
// For now, use a small allowlist of types we don't reject.
|
||||
// The main culprit in practice is that !numpy.any_dtype might be present
|
||||
// if shape/dtype inference wasn't good enough.
|
||||
// The main culprit in practice is an unknown dtype
|
||||
// when RefineTypes isn't smart enough to propagate it everywhere.
|
||||
// For tensors, we consider the post-conversion tensor type (this pass is
|
||||
// doing a type conversion).
|
||||
auto isValidLinalgType = [](Type type) {
|
||||
if (auto rankedTensor = type.dyn_cast<RankedTensorType>()) {
|
||||
if (auto tensor = type.dyn_cast<ValueTensorType>()) {
|
||||
if (auto rankedTensor =
|
||||
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>()) {
|
||||
if (BaseMemRefType::isValidElementType(rankedTensor.getElementType()))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (type.isa<FloatType, IntegerType, IndexType>())
|
||||
return true;
|
||||
return false;
|
||||
|
@ -60,15 +66,21 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op, PatternRewriter
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult convertMmOp(AtenMmOp op, PatternRewriter &rewriter) {
|
||||
namespace {
|
||||
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenMmOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
Value lhs = op.getOperand(0);
|
||||
Value rhs = op.getOperand(1);
|
||||
Value lhs = operands[0];
|
||||
Value rhs = operands[1];
|
||||
|
||||
// A user can write an errorneous program where `aten.mm` is in fact called
|
||||
// with operands of invalid rank or dtype. We cannot convert to linalg in this
|
||||
// case or we will get a verifier error, which corresponds to breaking of
|
||||
// *internal* compiler invariants, and for a user manifests as a compiler
|
||||
// with operands of invalid rank or dtype. We cannot convert to linalg in
|
||||
// this case or we will get a verifier error, which corresponds to breaking
|
||||
// of *internal* compiler invariants, and for a user manifests as a compiler
|
||||
// crash in the worst case (such as we try to canonicalize/fold/print the
|
||||
// invalid op before the verifier gets to see it -- also release builds of a
|
||||
// mature copmiler usually have the verifier turned off for compile time
|
||||
|
@ -94,10 +106,12 @@ LogicalResult convertMmOp(AtenMmOp op, PatternRewriter &rewriter) {
|
|||
rewriter.getStringAttr(
|
||||
"mismatching contracting dimension for torch.aten.mm"));
|
||||
|
||||
Type elementType = op.getType().cast<TensorType>().getElementType();
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
Type elementType = newResultType.cast<TensorType>().getElementType();
|
||||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, ValueRange{lhsDim0, rhsDim1}, elementType);
|
||||
Value c0 = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
||||
Value c0 =
|
||||
rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
||||
Value zeroFill =
|
||||
rewriter.create<linalg::FillOp>(loc, initTensor, c0).getResult(0);
|
||||
Value matmul = rewriter
|
||||
|
@ -106,22 +120,31 @@ LogicalResult convertMmOp(AtenMmOp op, PatternRewriter &rewriter) {
|
|||
.getResult(0);
|
||||
// When constructed with just dynamic sizes, InitTensorOp will have a result
|
||||
// type which has all `?`'s for dimensions, which might not be the result
|
||||
// type of `op`. The constraints on later linalg ops means that the result of
|
||||
// the MatmulOp will have this type too. So cast it to the desired type so
|
||||
// that in the end we have the original result type.
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), matmul);
|
||||
// type of `op`. The constraints on later linalg ops means that the result
|
||||
// of the MatmulOp will have this type too. So cast it to the desired type
|
||||
// so that in the end we have the original result type.
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||||
|
||||
return success();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// See comments at in convertMmOp and the heading for this section for general
|
||||
// considerations. This function needs to be auto-generated.
|
||||
LogicalResult convertLinearOp(AtenLinearOp op, PatternRewriter &rewriter) {
|
||||
class ConvertAtenLinearOp : public OpConversionPattern<AtenLinearOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenLinearOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
AtenLinearOp::Adaptor adaptor(operands);
|
||||
MLIRContext *context = op->getContext();
|
||||
Location loc = op->getLoc();
|
||||
Value input = op.input();
|
||||
Value weight = op.weight();
|
||||
Value bias = op.bias();
|
||||
Value input = adaptor.input();
|
||||
Value weight = adaptor.weight();
|
||||
Value bias = adaptor.bias();
|
||||
// TODO: Handle the case of bias being None (bias is optional).
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
@ -181,7 +204,8 @@ LogicalResult convertLinearOp(AtenLinearOp op, PatternRewriter &rewriter) {
|
|||
/*dimCount=*/2, /*symbolCount=*/0, rewriter.getAffineDimExpr(1)),
|
||||
rewriter.getMultiDimIdentityMap(2)};
|
||||
SmallVector<StringRef> iteratorTypes(2, "parallel");
|
||||
Value broadcasted = rewriter
|
||||
Value broadcasted =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initTensor.getType(), bias, initTensor,
|
||||
/*indexingMaps=*/broadcastIndexingMaps,
|
||||
|
@ -192,8 +216,8 @@ LogicalResult convertLinearOp(AtenLinearOp op, PatternRewriter &rewriter) {
|
|||
.getResult(0);
|
||||
// We need a matmul with dimension ordering (N, K) * (M, K), so transpose
|
||||
// the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
|
||||
// TODO: This whole aten.linear lowering should eventually be generated from a
|
||||
// single linalg ODS generator statement. Both the bias and matmul part.
|
||||
// TODO: This whole aten.linear lowering should eventually be generated from
|
||||
// a single linalg ODS generator statement. Both the bias and matmul part.
|
||||
SmallVector<AffineMap> transposeIndexingMaps = {
|
||||
AffineMap::get(
|
||||
/*dimCount=*/2, /*symbolCount=*/0,
|
||||
|
@ -213,12 +237,17 @@ LogicalResult convertLinearOp(AtenLinearOp op, PatternRewriter &rewriter) {
|
|||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
Value matmul = rewriter.create<linalg::MatmulOp>(
|
||||
loc, broadcasted.getType(), ValueRange{input, transposedWeights},
|
||||
broadcasted).getResult(0);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), matmul);
|
||||
Value matmul = rewriter
|
||||
.create<linalg::MatmulOp>(
|
||||
loc, broadcasted.getType(),
|
||||
ValueRange{input, transposedWeights}, broadcasted)
|
||||
.getResult(0);
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Converts a unary op. There is no implicit broadcasting behavior, so these can
|
||||
|
@ -227,20 +256,24 @@ namespace {
|
|||
// N-ary broadcasting and allows us to do multiversioning techniques for
|
||||
// lowering to linalg. We can trivially handle this as through that
|
||||
// abstraction instead.
|
||||
struct ConvertUnaryOp : RewritePattern {
|
||||
ConvertUnaryOp(MLIRContext *context)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
|
||||
struct ConvertUnaryOp : ConversionPattern {
|
||||
ConvertUnaryOp(TypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenTanhOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a unary op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Value operand = op->getOperand(0);
|
||||
auto type = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
Value operand = operands[0];
|
||||
auto type = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto rank = type.getRank();
|
||||
|
||||
SmallVector<StringRef> iteratorTypes(rank, "parallel");
|
||||
|
@ -277,19 +310,31 @@ public:
|
|||
registry.insert<linalg::LinalgDialect>();
|
||||
registry.insert<memref::MemRefDialect>();
|
||||
registry.insert<math::MathDialect>();
|
||||
registry.insert<StandardOpsDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||
}
|
||||
|
||||
FrozenRewritePatternSet getPatterns() {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
memref::MemRefDialect, math::MathDialect,
|
||||
tensor::TensorDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add(convertMmOp);
|
||||
patterns.add(convertLinearOp);
|
||||
patterns.add<ConvertUnaryOp>(context);
|
||||
return std::move(patterns);
|
||||
target.addIllegalOp<AtenMmOp>();
|
||||
patterns.add<ConvertAtenMmOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenLinearOp>();
|
||||
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenTanhOp>();
|
||||
patterns.add<ConvertUnaryOp>(typeConverter, context);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -11,11 +11,12 @@
|
|||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
@ -24,17 +25,23 @@ using namespace mlir::NPCOMP::Torch;
|
|||
// -----------------------------------------------------------------------------
|
||||
// Patterns (as this grows, it should be organized into multiple files)
|
||||
// -----------------------------------------------------------------------------
|
||||
// This is going to eventually be O(#aten ops), which is in the 100s.
|
||||
// This is going to eventually be O(#torch operators), which is in the 100s.
|
||||
|
||||
namespace {
|
||||
// Note: Confusingly, ATen's "dim" means "number of dimensions" which is what
|
||||
// MLIR calls "rank".
|
||||
LogicalResult convertDimOp(AtenDimOp op, PatternRewriter &rewriter) {
|
||||
if (!op.getOperand().getType().isa<TensorType>())
|
||||
return rewriter.notifyMatchFailure(op, "must be tensor only");
|
||||
auto rank = rewriter.create<RankOp>(op->getLoc(), op.getOperand());
|
||||
class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenDimOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rank = rewriter.create<RankOp>(op->getLoc(), operands[0]);
|
||||
rewriter.replaceOpWithNewOp<IndexCastOp>(op, op.getType(), rank);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult convertNeIntOp(AtenNeIntOp op, PatternRewriter &rewriter) {
|
||||
auto i1 = rewriter.create<CmpIOp>(op->getLoc(), CmpIPredicate::ne,
|
||||
|
@ -50,6 +57,15 @@ LogicalResult convertGtIntOp(AtenGtIntOp op, PatternRewriter &rewriter) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult convertTensorOp(TensorOp op, PatternRewriter &rewriter) {
|
||||
auto constant = rewriter.create<ConstantOp>(op->getLoc(), op.value());
|
||||
auto vtensor = rewriter.create<FromBuiltinTensorOp>(op->getLoc(), constant);
|
||||
Value result = copyTensorToType(rewriter, op->getLoc(),
|
||||
op.getType().cast<BaseTensorType>(), vtensor);
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -62,16 +78,27 @@ public:
|
|||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||
}
|
||||
|
||||
FrozenRewritePatternSet getPatterns() {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<Torch::TorchDialect, StandardOpsDialect,
|
||||
Basicpy::BasicpyDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add(convertDimOp);
|
||||
target.addIllegalOp<AtenDimOp>();
|
||||
patterns.add<ConvertAtenDimOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNeIntOp>();
|
||||
patterns.add(convertNeIntOp);
|
||||
target.addIllegalOp<AtenGtIntOp>();
|
||||
patterns.add(convertGtIntOp);
|
||||
return std::move(patterns);
|
||||
target.addIllegalOp<TensorOp>();
|
||||
patterns.add(convertTensorOp);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
add_npcomp_conversion_library(NPCOMPNumpyPasses
|
||||
ArrayToTensor.cpp
|
||||
Passes.cpp
|
||||
PublicFunctionToTensor.cpp
|
||||
RefinePublicReturn.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy/Transforms
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_npcomp_dialect_library(NPCOMPTorchDialect
|
||||
TorchDialect.cpp
|
||||
TorchOps.cpp
|
||||
TorchUtils.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch
|
||||
|
@ -18,6 +19,5 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
|
|||
MLIRControlFlowInterfaces
|
||||
MLIRSideEffectInterfaces
|
||||
NPCOMPBasicpyDialect
|
||||
NPCOMPNumpyDialect
|
||||
NPCOMPInterfaces
|
||||
)
|
||||
|
|
|
@ -13,13 +13,13 @@
|
|||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -79,6 +79,266 @@ void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
|
|||
llvm_unreachable("unknown 'torch' type");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BaseTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: Move most of this to a new file TorchTypes.cpp.
|
||||
|
||||
static bool isValidTorchDtype(Type dtype) {
|
||||
// Torch quantized types.
|
||||
if (dtype.isa<Torch::QInt8Type>())
|
||||
return true;
|
||||
// Builtin floating point types.
|
||||
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
|
||||
return true;
|
||||
// Builtin integer types.
|
||||
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
|
||||
if (type.isSignless() && type.getWidth() == 1)
|
||||
return true;
|
||||
if (type.isSigned()) {
|
||||
for (unsigned width : {8, 16, 32, 64}) {
|
||||
if (type.getWidth() == width)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (type.isUnsigned()) {
|
||||
return type.getWidth() == 8;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
|
||||
return getOptionalSizes() == other.getOptionalSizes() &&
|
||||
getOptionalDtype() == other.getOptionalDtype();
|
||||
}
|
||||
|
||||
Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
|
||||
return getWithSizesAndDtype(other.getOptionalSizes(),
|
||||
other.getOptionalDtype());
|
||||
}
|
||||
|
||||
Type BaseTensorType::getWithSizesAndDtype(
|
||||
Optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
|
||||
if (isa<NonValueTensorType>())
|
||||
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
|
||||
if (isa<ValueTensorType>())
|
||||
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
|
||||
llvm_unreachable("not a BaseTensorType!");
|
||||
}
|
||||
|
||||
ValueTensorType BaseTensorType::getWithValueSemantics() const {
|
||||
if (auto tensor = dyn_cast<NonValueTensorType>())
|
||||
return tensor.getWithValueSemantics();
|
||||
if (auto tensor = dyn_cast<ValueTensorType>())
|
||||
return tensor;
|
||||
llvm_unreachable("not a BaseTensorType!");
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
|
||||
Optional<ArrayRef<int64_t>> optionalSizes,
|
||||
Type optionalDtype) {
|
||||
if (optionalDtype && !isValidTorchDtype(optionalDtype)) {
|
||||
emitError() << "invalid dtype " << optionalDtype
|
||||
<< " for !torch.tensor type";
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
|
||||
GetTensorTypeFn getTensorType) {
|
||||
llvm::SMLoc startLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOptionalLess())
|
||||
return getTensorType(context,
|
||||
/*optionalSizes=*/None, /*optionalDtype=*/Type());
|
||||
bool hasSizes;
|
||||
SmallVector<int64_t> sizes;
|
||||
if (succeeded(parser.parseOptionalStar())) {
|
||||
// Unranked.
|
||||
hasSizes = false;
|
||||
} else {
|
||||
// Parse list of sizes.
|
||||
hasSizes = true;
|
||||
if (parser.parseLSquare())
|
||||
return Type();
|
||||
for (bool first = true;; first = false) {
|
||||
if (!first) {
|
||||
if (failed(parser.parseOptionalComma())) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (succeeded(parser.parseOptionalQuestion())) {
|
||||
sizes.push_back(-1);
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t size;
|
||||
auto optionalInt = parser.parseOptionalInteger(size);
|
||||
if (optionalInt.hasValue()) {
|
||||
if (failed(*optionalInt))
|
||||
return Type();
|
||||
sizes.push_back(size);
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (parser.parseRSquare()) {
|
||||
return Type();
|
||||
}
|
||||
}
|
||||
if (parser.parseComma())
|
||||
return Type();
|
||||
Type optionalDtype;
|
||||
if (succeeded(parser.parseOptionalKeyword("unk"))) {
|
||||
// Unknown dtype.
|
||||
} else {
|
||||
// Known dtype.
|
||||
if (parser.parseType(optionalDtype))
|
||||
return Type();
|
||||
}
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
Optional<ArrayRef<int64_t>> optionalSizes;
|
||||
if (hasSizes)
|
||||
optionalSizes.emplace(sizes);
|
||||
|
||||
if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); },
|
||||
optionalSizes, optionalDtype)))
|
||||
return Type();
|
||||
|
||||
return getTensorType(context, optionalSizes, optionalDtype);
|
||||
}
|
||||
|
||||
static void printTensorType(DialectAsmPrinter &printer,
|
||||
Optional<ArrayRef<int64_t>> optionalSizes,
|
||||
Type optionalDtype) {
|
||||
if (!optionalSizes && !optionalDtype)
|
||||
return;
|
||||
printer << "<";
|
||||
if (optionalSizes) {
|
||||
printer << "[";
|
||||
for (auto it : llvm::enumerate(*optionalSizes)) {
|
||||
if (it.index() > 0)
|
||||
printer << ",";
|
||||
if (it.value() < 0)
|
||||
printer << "?";
|
||||
else
|
||||
printer << it.value();
|
||||
}
|
||||
printer << "]";
|
||||
} else {
|
||||
printer << "*";
|
||||
}
|
||||
printer << ",";
|
||||
if (optionalDtype)
|
||||
printer.printType(optionalDtype);
|
||||
else
|
||||
printer << "unk";
|
||||
printer << ">";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NonValueTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ValueTensorType NonValueTensorType::getWithValueSemantics() const {
|
||||
return ValueTensorType::get(getContext(), getOptionalSizes(),
|
||||
getOptionalDtype());
|
||||
}
|
||||
|
||||
NonValueTensorType
|
||||
NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
||||
return NonValueTensorType::get(context,
|
||||
/*optionalSizes=*/None,
|
||||
/*optionalDtype=*/Type());
|
||||
}
|
||||
|
||||
NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) {
|
||||
return NonValueTensorType::get(type.getContext(),
|
||||
type.hasRank() ? type.getShape()
|
||||
: Optional<ArrayRef<int64_t>>(),
|
||||
type.getElementType());
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
NonValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
Optional<ArrayRef<int64_t>> optionalSizes,
|
||||
Type optionalDtype) {
|
||||
return verifyTensorType(emitError, optionalSizes, optionalDtype);
|
||||
}
|
||||
|
||||
Type NonValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
|
||||
return parseTensorType(
|
||||
context, parser,
|
||||
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
|
||||
Type optionalType) {
|
||||
return NonValueTensorType::get(context, optionalSizes, optionalType);
|
||||
});
|
||||
}
|
||||
|
||||
void NonValueTensorType::print(DialectAsmPrinter &printer) const {
|
||||
printer << "tensor";
|
||||
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
NonValueTensorType ValueTensorType::getWithoutValueSemantics() const {
|
||||
return NonValueTensorType::get(getContext(), getOptionalSizes(),
|
||||
getOptionalDtype());
|
||||
}
|
||||
|
||||
ValueTensorType
|
||||
ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
||||
return ValueTensorType::get(context,
|
||||
/*optionalSizes=*/None,
|
||||
/*optionalDtype=*/Type());
|
||||
}
|
||||
|
||||
ValueTensorType ValueTensorType::getFromShaped(ShapedType type) {
|
||||
return ValueTensorType::get(type.getContext(),
|
||||
type.hasRank() ? type.getShape()
|
||||
: Optional<ArrayRef<int64_t>>(),
|
||||
type.getElementType());
|
||||
}
|
||||
|
||||
TensorType ValueTensorType::toBuiltinTensor() const {
|
||||
if (!hasDtype())
|
||||
return nullptr;
|
||||
if (!hasSizes())
|
||||
return UnrankedTensorType::get(getDtype());
|
||||
return RankedTensorType::get(getSizes(), getDtype());
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
Optional<ArrayRef<int64_t>> optionalSizes,
|
||||
Type optionalDtype) {
|
||||
return verifyTensorType(emitError, optionalSizes, optionalDtype);
|
||||
}
|
||||
|
||||
Type ValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
|
||||
return parseTensorType(
|
||||
context, parser,
|
||||
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
|
||||
Type optionalType) {
|
||||
return ValueTensorType::get(context, optionalSizes, optionalType);
|
||||
});
|
||||
}
|
||||
|
||||
void ValueTensorType::print(DialectAsmPrinter &printer) const {
|
||||
printer << "vtensor";
|
||||
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect-level verifiers.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
||||
unsigned regionIndex,
|
||||
unsigned argIndex,
|
||||
|
@ -90,13 +350,13 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
|||
TypeAttr attr = namedAttr.second.dyn_cast<TypeAttr>();
|
||||
if (!attr)
|
||||
return op->emitError() << "'torch.type_bound' must be TypeAttr";
|
||||
auto type = attr.getValue().dyn_cast<Numpy::NdArrayType>();
|
||||
auto type = attr.getValue().dyn_cast<BaseTensorType>();
|
||||
if (!type)
|
||||
return op->emitError()
|
||||
<< "'torch.type_bound' must be of !numpy.ndarray type";
|
||||
if (!func.getType().getInput(argIndex).isa<Numpy::NdArrayType>())
|
||||
return op->emitError() << "'torch.type_bound' must be of "
|
||||
"!torch.tensor/!torch.vtensor type";
|
||||
if (!func.getType().getInput(argIndex).isa<BaseTensorType>())
|
||||
return op->emitError() << "'torch.type_bound' must be attached to an "
|
||||
"argument of !numpy.ndarray type";
|
||||
"argument of !torch.tensor/!torch.vtensor type";
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -104,6 +364,10 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
|||
<< "'";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constant materializer.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
|
@ -113,5 +377,15 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
|||
if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1)
|
||||
return builder.create<Basicpy::BoolConstantOp>(loc, type, i1Value);
|
||||
}
|
||||
// i64 is how we model TorchScript's "scalar integer type" (we could have a
|
||||
// proper !torch.int type in theory). None of our canonicalizers should be
|
||||
// creating any other integer type (except perhaps i1 after we resolve that
|
||||
// situation). All other integer types live inside tensors (that is, they are
|
||||
// never the direct result of an operation, and are thus never candidates for
|
||||
// constant materialization).
|
||||
if (auto integerType = type.dyn_cast<IntegerType>()) {
|
||||
if (integerType.getWidth() == 64)
|
||||
return builder.create<ConstantOp>(loc, value);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -12,16 +12,36 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||
BaseTensorType newType,
|
||||
Value tensor) {
|
||||
auto originalType = tensor.getType().cast<BaseTensorType>();
|
||||
// Adjust the static information in the type to match between the original and
|
||||
// new types.
|
||||
if (!originalType.hasSameSizesAndDtype(newType)) {
|
||||
tensor = builder.create<TensorStaticInfoCastOp>(
|
||||
loc, originalType.getWithSizesAndDtypeFrom(newType), tensor);
|
||||
}
|
||||
// If both the original and new types already have value semantics, a copy is
|
||||
// pointless.
|
||||
if (originalType.isa<ValueTensorType>() && newType.isa<ValueTensorType>())
|
||||
return tensor;
|
||||
return builder.create<CopyTensorOp>(loc, newType, tensor);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MethodOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -65,12 +85,20 @@ static LogicalResult verify(NnModuleOp op) {
|
|||
// This is a restricted subset of it.
|
||||
//
|
||||
// TODO: Flesh this out.
|
||||
// TODO: Decide / properly model the distinction between PEP 483 / Python
|
||||
// subtyping vs "more static information".
|
||||
bool isValidSubtype(Type subtype, Type type) {
|
||||
if (subtype == type)
|
||||
return true;
|
||||
if (auto optional = type.dyn_cast<OptionalType>())
|
||||
return subtype == optional.getContainedType() ||
|
||||
subtype.isa<Basicpy::NoneType>();
|
||||
// TODO: This is not subtyping according to PEP 483. See description
|
||||
// of NonValueTensorType.
|
||||
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
|
||||
type ==
|
||||
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -203,14 +231,44 @@ OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenLenTOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes())
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
||||
tensorType.getSizes().size());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenLenTOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
|
||||
// `len([1,1,1])` -> `3`
|
||||
if (auto buildList = getOperand().getDefiningOp<Basicpy::BuildListOp>()) {
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
||||
buildList.getNumOperands());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
// `len(t.size())` -> `t.ndim`
|
||||
patterns.add(+[](AtenLenTOp op, PatternRewriter &rewriter) {
|
||||
auto buildList = op.getOperand().getDefiningOp<Basicpy::BuildListOp>();
|
||||
if (!buildList)
|
||||
return rewriter.notifyMatchFailure(op, "operand not basicpy.build_list");
|
||||
rewriter.replaceOpWithNewOp<::mlir::ConstantOp>(
|
||||
op, rewriter.getI64IntegerAttr(buildList.getNumOperands()));
|
||||
auto size = op.getOperand().getDefiningOp<AtenSizeOp>();
|
||||
if (!size)
|
||||
return rewriter.notifyMatchFailure(op, "operand not AtenSizeOp");
|
||||
// TODO: Normalize all the torch scalar integer types to consistently use
|
||||
// a `!torch.int` type so that this op and others can automatically infer
|
||||
// their type. An additional benefit is that there's already enough of a
|
||||
// semantic gap between Python ints (which tend to be arbitrary precision)
|
||||
// and Torch/et-al ints (fixed bit depth, usually 64), it would be nice to
|
||||
// preserve the fact that we are working on a !torch.int and not just a
|
||||
// thing that was prematurely pinned to an `i64`.
|
||||
rewriter.replaceOpWithNewOp<AtenDimOp>(op, rewriter.getI64Type(),
|
||||
size.getOperand());
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
@ -222,11 +280,11 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
|
||||
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return rewriter.notifyMatchFailure(op, "not a ranked tensor");
|
||||
auto type = op.getOperand().getType().dyn_cast<BaseTensorType>();
|
||||
if (!type || !type.areAllSizesKnown())
|
||||
return rewriter.notifyMatchFailure(op, "all sizes not known");
|
||||
SmallVector<Value> listElements;
|
||||
for (int64_t size : type.getShape()) {
|
||||
for (int64_t size : type.getSizes()) {
|
||||
listElements.push_back(rewriter.create<::mlir::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI64IntegerAttr(size)));
|
||||
}
|
||||
|
@ -234,6 +292,137 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
op, Basicpy::ListType::get(rewriter.getContext()), listElements);
|
||||
return success();
|
||||
});
|
||||
// One-off pattern to erase if dead.
|
||||
// TODO: Use the effects infra to express the semantics of this op and enable
|
||||
// a centralized "erase if dead" canonicalization.
|
||||
// Specifically, we need to mark the op as only MemoryEffects::Allocate
|
||||
// so that `mlir::wouldOpBeTriviallyDead` does the right thing.
|
||||
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
|
||||
if (!op.use_empty())
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
TensorOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes,
|
||||
RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
|
||||
if (!attr)
|
||||
return failure();
|
||||
auto tensorType = attr.getType().cast<RankedTensorType>();
|
||||
inferredReturnTypes.push_back(NonValueTensorType::getFromShaped(tensorType));
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
|
||||
if (a.hasSizes() && b.hasSizes()) {
|
||||
if (failed(verifyCompatibleShape(a.getSizes(), b.getSizes())))
|
||||
return false;
|
||||
}
|
||||
if (a.hasDtype() && b.hasDtype()) {
|
||||
if (a.getDtype() != b.getDtype())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TensorOp::isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) {
|
||||
if (!actual[0].isa<BaseTensorType>())
|
||||
return false;
|
||||
return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(),
|
||||
actual[0].cast<BaseTensorType>());
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// TensorStaticInfoCast
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
|
||||
mlir::TypeRange outputs) {
|
||||
return areSizesAndDtypesCompatible(inputs[0].cast<BaseTensorType>(),
|
||||
outputs[0].cast<BaseTensorType>());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CopyTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(CopyTensorOp op) {
|
||||
auto resultType = op.getResult().getType().cast<BaseTensorType>();
|
||||
auto operandType = op.getOperand().getType().cast<BaseTensorType>();
|
||||
if (!resultType.hasSameSizesAndDtype(operandType)) {
|
||||
return op.emitError()
|
||||
<< "operand and result must have same sizes and dtype";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult CopyTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||
// A copy between value semantic tensors is a no-op.
|
||||
if (getType().isa<ValueTensorType>() &&
|
||||
getOperand().getType().isa<ValueTensorType>()) {
|
||||
return getOperand();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
// y = torch.copy.tensor(hasOneUse@torch.copy.tensor(x)) -> x
|
||||
// Only safe when `y` and `x` have value semantics.
|
||||
patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) {
|
||||
auto otherCopy = op.getOperand().getDefiningOp<CopyTensorOp>();
|
||||
if (!otherCopy)
|
||||
return failure();
|
||||
if (otherCopy.getOperand().getType().isa<ValueTensorType>() &&
|
||||
op.getResult().getType().isa<ValueTensorType>() &&
|
||||
op.getOperand().hasOneUse()) {
|
||||
rewriter.replaceOp(op, {otherCopy.getOperand()});
|
||||
// TODO: Implement MemoryEffectOpInterface to handle the value/non-value
|
||||
// cases precisely. In this case, we specifically know that `otherCopy`
|
||||
// is dead so eagerly clean it up.
|
||||
rewriter.eraseOp(otherCopy);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToBuiltinTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto resultType =
|
||||
operands[0].getType().cast<ValueTensorType>().toBuiltinTensor();
|
||||
if (!resultType)
|
||||
return failure();
|
||||
inferredReturnTypes.push_back(resultType);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FromBuiltinTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult FromBuiltinTensorOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(
|
||||
ValueTensorType::getFromShaped(operands[0].getType().cast<TensorType>()));
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Torch/IR/TorchUtils.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
void mlir::NPCOMP::Torch::setupValueTensorToBuiltinTensorConversion(
|
||||
ConversionTarget &target, TypeConverter &typeConverter) {
|
||||
target.addLegalOp<Torch::ToBuiltinTensorOp, Torch::FromBuiltinTensorOp>();
|
||||
typeConverter.addConversion(
|
||||
[](Torch::ValueTensorType type) -> Optional<Type> {
|
||||
return type.toBuiltinTensor();
|
||||
});
|
||||
typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type,
|
||||
ValueRange inputs,
|
||||
Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<BaseTensorType>());
|
||||
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
|
||||
});
|
||||
auto sourceMaterialization = [](OpBuilder &builder, ValueTensorType type,
|
||||
ValueRange inputs, Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<TensorType>());
|
||||
return builder.create<FromBuiltinTensorOp>(loc, inputs[0]);
|
||||
};
|
||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||
typeConverter.addArgumentMaterialization(sourceMaterialization);
|
||||
}
|
|
@ -16,8 +16,6 @@
|
|||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
|
@ -51,15 +49,18 @@ public:
|
|||
// The incoporation of the torch.type_bound arg attr is context-dependent.
|
||||
|
||||
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
|
||||
if (auto ndarray = type.value().dyn_cast<Numpy::NdArrayType>()) {
|
||||
if (type.value().isa<NonValueTensorType>()) {
|
||||
auto typeBoundAttr =
|
||||
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
|
||||
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
|
||||
if (!bound.isa<ValueTensorType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
func, "unimplemented: preserving aliasing for non-value-semantic "
|
||||
"type bounds");
|
||||
conversion.addInputs(type.index(), typeBoundAttr
|
||||
? typeBoundAttr.getValue()
|
||||
: type.value());
|
||||
continue;
|
||||
// type is attached to ndarray type.
|
||||
// TODO: check if more specific?
|
||||
} else if (auto none = type.value().dyn_cast<Basicpy::NoneType>()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -108,9 +109,15 @@ public:
|
|||
continue;
|
||||
auto it = typeBoundMap.find({call.callee(), operand.index()});
|
||||
if (it != typeBoundMap.end()) {
|
||||
newOperands.push_back(rewriter.create<Numpy::StaticInfoCastOp>(
|
||||
call.getLoc(), it->second, operand.value()));
|
||||
if (auto valueTensorType = it->second.dyn_cast<ValueTensorType>()) {
|
||||
newOperands.push_back(copyTensorToType(
|
||||
rewriter, call->getLoc(), valueTensorType, operand.value()));
|
||||
continue;
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
call, "unimplemented: preserving aliasing for non-value-semantic "
|
||||
"type bounds");
|
||||
}
|
||||
}
|
||||
newOperands.push_back(operand.value());
|
||||
}
|
||||
|
@ -172,11 +179,11 @@ static LogicalResult adjustCallingConventions(FuncOp func,
|
|||
});
|
||||
|
||||
typeConverter.addArgumentMaterialization(
|
||||
[](OpBuilder &builder, Numpy::NdArrayType type, ValueRange inputs,
|
||||
[](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs,
|
||||
Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<Numpy::NdArrayType>());
|
||||
return builder.create<Numpy::StaticInfoCastOp>(loc, type, inputs[0]);
|
||||
assert(inputs[0].getType().isa<BaseTensorType>());
|
||||
return copyTensorToType(builder, loc, type, inputs[0]);
|
||||
});
|
||||
patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
|
||||
patterns.add<AdjustCallingConventionForCall>(typeConverter, context,
|
||||
|
@ -211,7 +218,8 @@ static LogicalResult adjustCallingConventions(FuncOp func,
|
|||
target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
|
||||
return !opsInOriginalProgram.contains(op.getOperation());
|
||||
});
|
||||
target.addLegalOp<Numpy::StaticInfoCastOp>();
|
||||
target.addLegalOp<CopyTensorOp>();
|
||||
target.addLegalOp<TensorStaticInfoCastOp>();
|
||||
target.addLegalOp<Basicpy::SingletonOp>();
|
||||
// We don't know how to rewrite it, so mark it as illegal.
|
||||
target.addIllegalOp<CallIndirectOp>();
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
//===- BuiltinTensorize.cpp --------------------------------------*- C++-*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchUtils.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
namespace {
|
||||
struct FuncBuiltinTensorizePass
|
||||
: public FuncBuiltinTensorizeBase<FuncBuiltinTensorizePass> {
|
||||
using FuncBuiltinTensorizeBase<
|
||||
FuncBuiltinTensorizePass>::FuncBuiltinTensorizeBase;
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
|
||||
|
||||
populateFuncOpTypeConversionPattern(patterns, typeConverter);
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getType()) &&
|
||||
typeConverter.isLegal(&op.getBody());
|
||||
});
|
||||
populateCallOpTypeConversionPattern(patterns, typeConverter);
|
||||
target.addDynamicallyLegalOp<CallOp>(
|
||||
[&](CallOp op) { return typeConverter.isLegal(op); });
|
||||
|
||||
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
|
||||
populateReturnOpTypeConversionPattern(patterns, typeConverter);
|
||||
target.addLegalOp<ModuleOp>();
|
||||
|
||||
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
|
||||
return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
|
||||
isLegalForBranchOpInterfaceTypeConversionPattern(op,
|
||||
typeConverter) ||
|
||||
isLegalForReturnOpTypeConversionPattern(op, typeConverter);
|
||||
});
|
||||
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Torch::createFuncBuiltinTensorizePass() {
|
||||
return std::make_unique<FuncBuiltinTensorizePass>();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// In a finalizing conversion, we know that all `!torch.vtensor` have been
|
||||
// converted to `tensor`, thus, this op becomes an identity.
|
||||
class FinalizeToBuiltinTensorOp
|
||||
: public OpConversionPattern<ToBuiltinTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ToBuiltinTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, operands[0]);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// In a finalizing conversion, we know that all `!torch.vtensor` have been
|
||||
// converted to `tensor`, thus, this op becomes an identity.
|
||||
class FinalizeFromBuiltinTensorOp
|
||||
: public OpConversionPattern<FromBuiltinTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(FromBuiltinTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, operands[0]);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
struct FinalizingBuiltinTensorizePass
|
||||
: public FinalizingBuiltinTensorizeBase<FinalizingBuiltinTensorizePass> {
|
||||
using FinalizingBuiltinTensorizeBase<
|
||||
FinalizingBuiltinTensorizePass>::FinalizingBuiltinTensorizeBase;
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
|
||||
target.addIllegalOp<ToBuiltinTensorOp, FromBuiltinTensorOp>();
|
||||
|
||||
patterns.add<FinalizeFromBuiltinTensorOp, FinalizeToBuiltinTensorOp>(
|
||||
typeConverter, context);
|
||||
|
||||
// If all result types are legal, and all block arguments are legal, then
|
||||
// all types in the program are legal.
|
||||
//
|
||||
// We also check that the operand types are legal to avoid creating invalid
|
||||
// IR. For example, this prevents the patterns from updating
|
||||
// the types of the operands to a return op without updating the enclosing
|
||||
// function.
|
||||
target.markUnknownOpDynamicallyLegal(
|
||||
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||
|
||||
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::Torch::createFinalizingBuiltinTensorizePass() {
|
||||
return std::make_unique<FinalizingBuiltinTensorizePass>();
|
||||
}
|
|
@ -1,10 +1,13 @@
|
|||
add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||
AdjustCallingConventions.cpp
|
||||
BuiltinTensorize.cpp
|
||||
Passes.cpp
|
||||
GlobalizeObjectGraph.cpp
|
||||
InlineGlobalSlots.cpp
|
||||
MaximizeValueSemantics.cpp
|
||||
PrepareForGlobalizeObjectGraph.cpp
|
||||
ReduceOpVariants.cpp
|
||||
RefinePublicReturn.cpp
|
||||
RefineTypes.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
@ -21,7 +24,6 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
|
|||
MLIRPass
|
||||
NPCOMPTorchDialect
|
||||
NPCOMPBasicpyDialect
|
||||
NPCOMPNumpyPasses
|
||||
NPCOMPTorchToLinalg
|
||||
NPCOMPTCFToStd
|
||||
NPCOMPInterfaces
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- ArrayToTensor.cpp -----------------------------------------*- C++-*-===//
|
||||
//===- MaximizeValueSemantics.cpp --------------------------------*- C++-*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -13,24 +13,24 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Numpy;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
namespace {
|
||||
|
||||
class ArrayToTensorPass : public NumpyArrayToTensorBase<ArrayToTensorPass> {
|
||||
class MaximizeValueSemanticsPass
|
||||
: public MaximizeValueSemanticsBase<MaximizeValueSemanticsPass> {
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
auto func = getOperation();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
CopyToTensorOp::getCanonicalizationPatterns(patterns, context);
|
||||
StaticInfoCastOp::getCanonicalizationPatterns(patterns, context);
|
||||
CopyTensorOp::getCanonicalizationPatterns(patterns, context);
|
||||
TensorStaticInfoCastOp::getCanonicalizationPatterns(patterns, context);
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
@ -38,6 +38,6 @@ class ArrayToTensorPass : public NumpyArrayToTensorBase<ArrayToTensorPass> {
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::Numpy::createArrayToTensorPass() {
|
||||
return std::make_unique<ArrayToTensorPass>();
|
||||
mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass() {
|
||||
return std::make_unique<MaximizeValueSemanticsPass>();
|
||||
}
|
|
@ -13,7 +13,6 @@
|
|||
#include "npcomp/Backend/Common/Passes.h"
|
||||
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
|
@ -94,7 +93,7 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
|||
if (options.optimize) {
|
||||
// Inline global slots, which for most inference scenarios deletes them.
|
||||
// This also exposes more information to intraprocedural transformations
|
||||
// below like ArrayToTensor and RefineTypes.
|
||||
// below like MaximizeValueSemantics and RefineTypes.
|
||||
// OPT-ONLY: Don't rely on this pass to "lower" global slots by deleting.
|
||||
// Also don't rely on this pass to expose constants into the program to
|
||||
// simplify handling of "optional".
|
||||
|
@ -103,9 +102,6 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
|||
|
||||
// Reduce variants of ops to a smaller set of primitives.
|
||||
pm.addNestedPass<FuncOp>(createReduceOpVariantsPass());
|
||||
// Convert any operations on primitive types as soon as possible. Unlike
|
||||
// tensor compute ops, we don't need to wait for dtype/shape inference.
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
||||
|
||||
if (options.optimize) {
|
||||
// OPT-ONLY: Right now we rely on this to eliminate certain branches that
|
||||
|
@ -119,30 +115,36 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
|||
pm.addPass(createSymbolDCEPass());
|
||||
}
|
||||
|
||||
// Convert the bulk of the program to ranked tensors with known dtype.
|
||||
// This is the input to the backend layer that we are aiming for.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Lowering to ranked !torch.vtensors of known dtype.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// First, unilaterally convert public functions to tensor.
|
||||
// The way this pass is currently written, this implies that
|
||||
// as pipeline authors, we are restricting our users to not be able to see
|
||||
// updates to "out params" on their public functions.
|
||||
// This is deemed ok for now.
|
||||
pm.addPass(Numpy::createPublicFunctionsToTensorPass());
|
||||
// Convert the bulk of non-ABI-visible arrays to tensors.
|
||||
pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
|
||||
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
||||
// Do shape and dtype refinement.
|
||||
// We could do it sooner, but the pass currently doesn't have transfer
|
||||
// functions for array ops.
|
||||
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
|
||||
// Propagate to ABI return types the shape/dtype information discovered by
|
||||
// the previous pass. Doing this is ABI-compatible for our backends.
|
||||
pm.addPass(Numpy::createRefinePublicReturnPass());
|
||||
// Clean up a few stray array/tensor conversion remnants.
|
||||
pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
|
||||
pm.addPass(Torch::createRefinePublicReturnPass());
|
||||
// Clean up a few stray conversion remnants.
|
||||
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Lowering ops and the !torch.vtensor type.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Convert any operations on primitive types. These need at least basic dtype
|
||||
// inference, otherwise we cannot interop with builtin tensors.
|
||||
// Run this before this canonicalizer as this will expose optimization
|
||||
// opportunities thanks to folders on std ops that we don't have on the
|
||||
// corresponding torch ops.
|
||||
// TODO: Improve torch op canonicalizations.
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
||||
|
||||
if (options.optimize) {
|
||||
// RefineTypes has exposed new type information that allows folding away
|
||||
// more stuff. OPT-ONLY: Right now we rely on this to eliminate certain
|
||||
// more stuff.
|
||||
// OPT-ONLY: Right now we rely on this to eliminate certain
|
||||
// branches that guard unreachable code that backends can't handle yet, such
|
||||
// as lists, RaiseException, unimplemented aten ops, and
|
||||
// only-used-in-training operations on `torch.global_slot`'s.
|
||||
|
@ -152,6 +154,15 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
|||
// Lower to linalg + guards which is the input to codegen backends.
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
||||
|
||||
if (options.optimize) {
|
||||
// Clean up any non-canonical code introduced in our linalg lowering.
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
}
|
||||
|
||||
// Finish the type conversion from !torch.vtensor to the builtin tensor type.
|
||||
pm.addPass(createFuncBuiltinTensorizePass());
|
||||
pm.addNestedPass<FuncOp>(createFinalizingBuiltinTensorizePass());
|
||||
|
||||
// Verify that we have lowered to the form that backends expect.
|
||||
// This fails compilation (signalPassFailure) if the IR is not in the
|
||||
// correct form.
|
||||
|
|
|
@ -9,8 +9,6 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -35,23 +33,23 @@ public:
|
|||
// Convert all operands.
|
||||
SmallVector<Value> newOperands;
|
||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||
auto ndArrayType =
|
||||
opOperand.get().getType().dyn_cast<Numpy::NdArrayType>();
|
||||
if (!ndArrayType)
|
||||
auto tensorType =
|
||||
opOperand.get().getType().dyn_cast<NonValueTensorType>();
|
||||
if (!tensorType)
|
||||
continue;
|
||||
opOperand.set(rewriter.create<Numpy::CopyToTensorOp>(
|
||||
op->getLoc(), ndArrayType.toTensorType(), opOperand.get()));
|
||||
opOperand.set(rewriter.create<CopyTensorOp>(
|
||||
op->getLoc(), tensorType.getWithValueSemantics(), opOperand.get()));
|
||||
}
|
||||
// Convert all results.
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
for (Value result : op->getResults()) {
|
||||
auto ndArrayType = result.getType().dyn_cast<Numpy::NdArrayType>();
|
||||
if (!ndArrayType)
|
||||
auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
|
||||
if (!tensorType)
|
||||
continue;
|
||||
auto createArray = rewriter.create<Numpy::CreateArrayFromTensorOp>(
|
||||
auto createArray = rewriter.create<CopyTensorOp>(
|
||||
op->getLoc(), result.getType(), result);
|
||||
result.replaceAllUsesExcept(createArray, createArray);
|
||||
result.setType(ndArrayType.toTensorType());
|
||||
result.setType(tensorType.getWithValueSemantics());
|
||||
}
|
||||
});
|
||||
return success();
|
||||
|
@ -87,12 +85,13 @@ public:
|
|||
"Torch JIT operators shouldn't have regions or successors");
|
||||
|
||||
Operation *newOp = rewriter.createOperation(state);
|
||||
auto tensor = rewriter.create<Numpy::CopyToTensorOp>(
|
||||
op->getLoc(),
|
||||
newOp->getResult(0).getType().cast<Numpy::NdArrayType>().toTensorType(),
|
||||
auto tensor = rewriter.create<CopyTensorOp>(op->getLoc(),
|
||||
newOp->getResult(0)
|
||||
.getType()
|
||||
.cast<NonValueTensorType>()
|
||||
.getWithValueSemantics(),
|
||||
newOp->getResult(0));
|
||||
rewriter.create<Numpy::OverwriteArrayOp>(op->getLoc(), tensor,
|
||||
op->getOperand(0));
|
||||
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
|
||||
return success();
|
||||
|
@ -111,9 +110,16 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
ConversionTarget target(*context);
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||
auto isNdArray = [](Type t) { return t.isa<Numpy::NdArrayType>(); };
|
||||
return llvm::none_of(op->getOperandTypes(), isNdArray) &&
|
||||
llvm::none_of(op->getResultTypes(), isNdArray);
|
||||
auto hasValueSemantics = [](Type t) {
|
||||
// TODO: Make this an allowlist based on a closed torch dialect
|
||||
// type system.
|
||||
if (auto tensorType = t.dyn_cast<NonValueTensorType>()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
return llvm::all_of(op->getOperandTypes(), hasValueSemantics) &&
|
||||
llvm::all_of(op->getResultTypes(), hasValueSemantics);
|
||||
}
|
||||
if (op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>()) {
|
||||
return false;
|
||||
|
|
|
@ -11,17 +11,17 @@
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Numpy;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
namespace {
|
||||
|
||||
class RefinePublicReturnPass
|
||||
: public NumpyRefinePublicReturnBase<RefinePublicReturnPass> {
|
||||
: public RefinePublicReturnBase<RefinePublicReturnPass> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
module.walk([&](FuncOp func) {
|
||||
|
@ -58,11 +58,20 @@ class RefinePublicReturnPass
|
|||
// TensorStaticInfoCastOp then the pre-casted operand, which is presumed to
|
||||
// have a more precise type.
|
||||
SmallVector<Value> newOperands;
|
||||
OpBuilder builder(returnOp);
|
||||
for (auto operand : returnOp.getOperands()) {
|
||||
Value newOperand;
|
||||
if (auto cast = operand.getDefiningOp<TensorStaticInfoCastOp>()) {
|
||||
newOperands.push_back(cast.getOperand());
|
||||
newOperand = cast.getOperand();
|
||||
} else {
|
||||
newOperands.push_back(operand);
|
||||
newOperand = operand;
|
||||
}
|
||||
if (auto tensorType = newOperand.getType().dyn_cast<BaseTensorType>()) {
|
||||
newOperands.push_back(
|
||||
copyTensorToType(builder, returnOp->getLoc(),
|
||||
tensorType.getWithValueSemantics(), newOperand));
|
||||
} else {
|
||||
newOperands.push_back(newOperand);
|
||||
}
|
||||
}
|
||||
returnOp->setOperands(newOperands);
|
||||
|
@ -77,6 +86,6 @@ class RefinePublicReturnPass
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Numpy::createRefinePublicReturnPass() {
|
||||
mlir::NPCOMP::Torch::createRefinePublicReturnPass() {
|
||||
return std::make_unique<RefinePublicReturnPass>();
|
||||
}
|
|
@ -18,8 +18,6 @@
|
|||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
|
@ -32,16 +30,14 @@ using namespace mlir::NPCOMP::Torch;
|
|||
// Analysis.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
constexpr int64_t kUnknownSize = -1;
|
||||
|
||||
static Type joinElementTypes(Type lhs, Type rhs) {
|
||||
if (lhs.isa<Numpy::AnyDtypeType>())
|
||||
if (!lhs)
|
||||
return rhs;
|
||||
if (rhs.isa<Numpy::AnyDtypeType>())
|
||||
if (!rhs)
|
||||
return lhs;
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
return Numpy::AnyDtypeType::get(lhs.getContext());
|
||||
return Type();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -56,28 +52,20 @@ namespace {
|
|||
// This class could also be called "dataflow facts", "lattice value", etc.
|
||||
struct ValueKnowledge {
|
||||
ValueKnowledge() = delete;
|
||||
// We enforce that `elementType` is always a valid type (possibly
|
||||
// !numpy.any_dtype), and `sizes` is empty unless `hasRank`.
|
||||
// So default constructing is prohibited.
|
||||
ValueKnowledge(bool hasRank, std::vector<int64_t> sizes, Type elementType)
|
||||
: hasRank(hasRank), sizes(sizes), elementType(elementType) {
|
||||
assert(elementType != nullptr);
|
||||
assert(sizes.size() == 0 || hasRank);
|
||||
ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
|
||||
: hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
|
||||
assert(sizes.size() == 0 || hasSizes);
|
||||
}
|
||||
|
||||
// Get the static knowledge intrinsic to `type`.
|
||||
static ValueKnowledge getKnowledgeFromType(Type type) {
|
||||
ValueKnowledge result = getPessimisticValueState(type.getContext());
|
||||
if (auto tensorType = type.dyn_cast<TensorType>()) {
|
||||
if (tensorType.hasRank()) {
|
||||
result.hasRank = true;
|
||||
result.sizes = tensorType.getShape().vec();
|
||||
if (auto tensorType = type.dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes()) {
|
||||
result.hasSizes = true;
|
||||
result.sizes = tensorType.getSizes().vec();
|
||||
}
|
||||
result.elementType = tensorType.getElementType();
|
||||
return result;
|
||||
}
|
||||
if (auto ndArrayType = type.dyn_cast<Numpy::NdArrayType>()) {
|
||||
return getKnowledgeFromType(ndArrayType.toTensorType());
|
||||
result.dtype = tensorType.getOptionalDtype();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -85,7 +73,7 @@ struct ValueKnowledge {
|
|||
// Return a pessimistic/conservative value state without assuming any knowlege
|
||||
// about the IR.
|
||||
static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
|
||||
return ValueKnowledge(false, {}, Numpy::AnyDtypeType::get(context));
|
||||
return ValueKnowledge(false, {}, Type());
|
||||
}
|
||||
// Return a pessimistic/conservative value state only using knowlege already
|
||||
// recorded in the IR.
|
||||
|
@ -94,8 +82,8 @@ struct ValueKnowledge {
|
|||
}
|
||||
|
||||
bool operator==(const ValueKnowledge &rhs) const {
|
||||
return std::make_tuple(hasRank, sizes, elementType) ==
|
||||
std::make_tuple(rhs.hasRank, rhs.sizes, rhs.elementType);
|
||||
return std::make_tuple(hasSizes, sizes, dtype) ==
|
||||
std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
|
||||
}
|
||||
|
||||
// Given two pieces of static knowledge, calculate conservatively the
|
||||
|
@ -105,18 +93,17 @@ struct ValueKnowledge {
|
|||
// Mental model: All conditions are checking how to change from the safe "no
|
||||
// knowledge" default-initialized state to a state with more knowledge
|
||||
// consistent with lhs and rhs.
|
||||
ValueKnowledge result =
|
||||
getPessimisticValueState(lhs.elementType.getContext());
|
||||
ValueKnowledge result = getPessimisticValueState(nullptr);
|
||||
|
||||
if (lhs.hasRank && !rhs.hasRank) {
|
||||
result.hasRank = true;
|
||||
if (lhs.hasSizes && !rhs.hasSizes) {
|
||||
result.hasSizes = true;
|
||||
result.sizes = lhs.sizes;
|
||||
} else if (!lhs.hasRank && rhs.hasRank) {
|
||||
result.hasRank = true;
|
||||
} else if (!lhs.hasSizes && rhs.hasSizes) {
|
||||
result.hasSizes = true;
|
||||
result.sizes = rhs.sizes;
|
||||
} else if (lhs.hasRank && rhs.hasRank &&
|
||||
} else if (lhs.hasSizes && rhs.hasSizes &&
|
||||
lhs.sizes.size() == rhs.sizes.size()) {
|
||||
result.hasRank = true;
|
||||
result.hasSizes = true;
|
||||
result.sizes.resize(lhs.sizes.size(), kUnknownSize);
|
||||
for (int i = 0, e = result.sizes.size(); i != e; i++) {
|
||||
int64_t lhsSize = lhs.sizes[i];
|
||||
|
@ -132,29 +119,21 @@ struct ValueKnowledge {
|
|||
}
|
||||
}
|
||||
|
||||
result.elementType = joinElementTypes(lhs.elementType, rhs.elementType);
|
||||
result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Whether the Value is known to have a rank.
|
||||
bool hasRank;
|
||||
// If `hasRank` the sizes along each rank. Unknown sizes are represented as
|
||||
// Whether the Value is known to have a list of sizes.
|
||||
bool hasSizes;
|
||||
// If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
|
||||
// `kUnknownSize`.
|
||||
std::vector<int64_t> sizes;
|
||||
// The element type of a shaped type.
|
||||
// This is equal to !numpy.any_dtype if it is not a concrete type.
|
||||
Type elementType;
|
||||
// The dtype of a tensor.
|
||||
// This is equal to nullptr if we don't know that it is a specific concrete
|
||||
// type.
|
||||
Type dtype;
|
||||
};
|
||||
|
||||
// static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ValueKnowledge
|
||||
// &knowledge) {
|
||||
// os << "hasRank = " << knowledge.hasRank << ", sizes = [";
|
||||
// llvm::interleaveComma(knowledge.sizes, os);
|
||||
// os << "]"
|
||||
// << ", elementType = " << knowledge.elementType;
|
||||
// return os;
|
||||
// }
|
||||
|
||||
// Forward intraprocedural dataflow for type information.
|
||||
class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
|
||||
public:
|
||||
|
@ -165,8 +144,7 @@ public:
|
|||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
|
||||
if (isa<Numpy::TensorStaticInfoCastOp, Numpy::CopyToTensorOp,
|
||||
Numpy::CreateArrayFromTensorOp, AtenTanhOp, AtenBatchNormOp,
|
||||
if (isa<TensorStaticInfoCastOp, CopyTensorOp, AtenTanhOp, AtenBatchNormOp,
|
||||
AtenReluOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
@ -175,7 +153,7 @@ public:
|
|||
auto &rhs = operands[1]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
knowledge.hasRank = true;
|
||||
knowledge.hasSizes = true;
|
||||
// WARNING: We could be more precise here by calculating the output
|
||||
// shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
|
||||
// at this stage in the compiler because we don't really have many static
|
||||
|
@ -186,8 +164,8 @@ public:
|
|||
//
|
||||
// Example: Suppose a user program calls `aten.mm` with two rank-0
|
||||
// operands. The program emits an error when invoked, but when running
|
||||
// this pass, we will (correctly!) infer `lhs.hasRank && lhs.sizes.size()
|
||||
// == 0 && rhs.hasRank && rhs.sizes.size() == 0` -- it's not safe to
|
||||
// this pass, we will (correctly!) infer `lhs.hasSizes && lhs.sizes.size()
|
||||
// == 0 && rhs.hasSizes && rhs.sizes.size() == 0` -- it's not safe to
|
||||
// access `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
|
||||
// function, it's not as simple as taking `lhs.sizes[0]` and
|
||||
// `rhs.sizes[1]`, as both of those might read out of bounds of the array.
|
||||
|
@ -201,51 +179,48 @@ public:
|
|||
// TODO: Investigate promotion rules if element types mismatch.
|
||||
// This is conservatively correct, assuming that if both element types are
|
||||
// the same, then the result is of that same element type.
|
||||
knowledge.elementType =
|
||||
joinElementTypes(lhs.elementType, rhs.elementType);
|
||||
knowledge.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<AtenLinearOp>(op)) {
|
||||
// The output shape is the input shape with the last dimension changed
|
||||
// to the weight's output dimension.
|
||||
auto knowledge = operands[0]->getValue();
|
||||
if (knowledge.hasRank && knowledge.sizes.size() > 0)
|
||||
if (knowledge.hasSizes && knowledge.sizes.size() > 0)
|
||||
knowledge.sizes[knowledge.sizes.size() - 1] = kUnknownSize;
|
||||
// TODO: Handle case of bias being None gracefully. Requires a lattice
|
||||
// that tracks "None" (torch.optional). See also
|
||||
// DerefineOp::getCanonicalizationPatterns for more refinement that needs
|
||||
// to be done in this pass.
|
||||
knowledge.elementType = joinElementTypes(
|
||||
knowledge.elementType,
|
||||
joinElementTypes(operands[1]->getValue().elementType,
|
||||
operands[2]->getValue().elementType));
|
||||
knowledge.dtype = joinElementTypes(
|
||||
knowledge.dtype, joinElementTypes(operands[1]->getValue().dtype,
|
||||
operands[2]->getValue().dtype));
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<AtenConv2dOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
knowledge.hasRank = true;
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
// Running some experiments in PyTorch, the bias doesn't seem to
|
||||
// contribute to the final element type.
|
||||
knowledge.elementType =
|
||||
joinElementTypes(operands[0]->getValue().elementType,
|
||||
operands[1]->getValue().elementType);
|
||||
knowledge.dtype = joinElementTypes(operands[0]->getValue().dtype,
|
||||
operands[1]->getValue().dtype);
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<AtenMaxPool2dOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
knowledge.hasRank = true;
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
knowledge.elementType = operands[0]->getValue().elementType;
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<AtenAdaptiveAvgPool2dOp>(op)) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
if (input.hasRank) {
|
||||
knowledge.hasRank = true;
|
||||
if (input.hasSizes) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(input.sizes.size(), kUnknownSize);
|
||||
}
|
||||
knowledge.elementType = input.elementType;
|
||||
knowledge.dtype = input.dtype;
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (isa<AtenAddTensorOp>(op)) {
|
||||
// This is a general binary broadcasting shape transfer function.
|
||||
|
@ -257,25 +232,24 @@ public:
|
|||
auto rhs = operands[1]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
if (lhs.hasRank && rhs.hasRank) {
|
||||
knowledge.hasRank = true;
|
||||
if (lhs.hasSizes && rhs.hasSizes) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
|
||||
kUnknownSize);
|
||||
}
|
||||
knowledge.elementType =
|
||||
joinElementTypes(lhs.elementType, rhs.elementType);
|
||||
knowledge.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
} else if (auto flatten = dyn_cast<AtenFlattenUsingIntsOp>(op)) {
|
||||
APInt startDimAP, endDimAP;
|
||||
auto operand = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
knowledge.elementType = operand.elementType;
|
||||
if (operand.hasRank && operand.sizes.size() == 0) {
|
||||
knowledge.dtype = operand.dtype;
|
||||
if (operand.hasSizes && operand.sizes.size() == 0) {
|
||||
// Rank 0 is special and flattens to rank 1.
|
||||
knowledge.hasRank = true;
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.push_back(kUnknownSize);
|
||||
} else if (operand.hasRank &&
|
||||
} else if (operand.hasSizes &&
|
||||
matchPattern(flatten.start_dim(),
|
||||
m_ConstantInt(&startDimAP)) &&
|
||||
matchPattern(flatten.end_dim(), m_ConstantInt(&endDimAP))) {
|
||||
|
@ -289,7 +263,7 @@ public:
|
|||
// Careful: dimension numbers might be out of bounds.
|
||||
if (0 <= startDim && startDim <= (inputRank - 1) && 0 <= endDim &&
|
||||
endDim <= (inputRank - 1) && startDim <= endDim) {
|
||||
knowledge.hasRank = true;
|
||||
knowledge.hasSizes = true;
|
||||
for (auto i = 0; i < startDim; i++)
|
||||
knowledge.sizes.push_back(operand.sizes[i]);
|
||||
knowledge.sizes.push_back(kUnknownSize);
|
||||
|
@ -310,43 +284,35 @@ public:
|
|||
// Transforms.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Get the most refined TensorType compatible with ValueKnowledge.
|
||||
static TensorType
|
||||
getTensorTypeFromKnowledge(MLIRContext *context,
|
||||
LatticeElement<ValueKnowledge> *knowledge) {
|
||||
if (!knowledge)
|
||||
return UnrankedTensorType::get(Numpy::AnyDtypeType::get(context));
|
||||
|
||||
const ValueKnowledge &value = knowledge->getValue();
|
||||
if (!value.hasRank)
|
||||
return UnrankedTensorType::get(value.elementType);
|
||||
return RankedTensorType::get(value.sizes, value.elementType);
|
||||
}
|
||||
|
||||
// Get the most refined Numpy::NdArrayType compatible with ValueKnowledge.
|
||||
static Numpy::NdArrayType
|
||||
getNdArrayTypeFromKnowledge(MLIRContext *context,
|
||||
LatticeElement<ValueKnowledge> *knowledge) {
|
||||
if (!knowledge)
|
||||
return Numpy::NdArrayType::get(Numpy::AnyDtypeType::get(context));
|
||||
|
||||
const ValueKnowledge &value = knowledge->getValue();
|
||||
if (!value.hasRank)
|
||||
return Numpy::NdArrayType::get(value.elementType);
|
||||
return Numpy::NdArrayType::get(value.elementType,
|
||||
llvm::makeArrayRef(value.sizes));
|
||||
}
|
||||
|
||||
// Get a the most refined type compatible with ValueKnowledge, or null if that
|
||||
// is not possible.
|
||||
static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
|
||||
if (v.getType().isa<TensorType>())
|
||||
return getTensorTypeFromKnowledge(v.getContext(),
|
||||
analyzer.lookupLatticeElement(v));
|
||||
if (v.getType().isa<Numpy::NdArrayType>())
|
||||
return getNdArrayTypeFromKnowledge(v.getContext(),
|
||||
analyzer.lookupLatticeElement(v));
|
||||
if (auto tensorType = v.getType().dyn_cast<BaseTensorType>()) {
|
||||
LatticeElement<ValueKnowledge> *latticeElement =
|
||||
analyzer.lookupLatticeElement(v);
|
||||
if (!latticeElement)
|
||||
return nullptr;
|
||||
const ValueKnowledge &knowledge = latticeElement->getValue();
|
||||
return tensorType.getWithSizesAndDtype(
|
||||
knowledge.hasSizes ? llvm::makeArrayRef(knowledge.sizes)
|
||||
: Optional<ArrayRef<int64_t>>(),
|
||||
knowledge.dtype);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Return true if we can safely change the operands or results of `op`.
|
||||
//
|
||||
// The most trivial case is when the op has the AllowsTypeRefinement trait,
|
||||
// which allows arbitrary refinements. But some other cases are safe too,
|
||||
// such as when an op has two types that are coupled, but we know that our
|
||||
// analysis and updating logic will correctly maintain the invariants of the op.
|
||||
// The `torch.copy.tensor` is an example of the latter case, since its
|
||||
// operand and result types must have the same shape and dtype -- we know
|
||||
// that our transfer functions and updating logic will do the right thing
|
||||
// for that op.
|
||||
static bool allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(Operation *op) {
|
||||
return allowsTypeRefinement(op) || isa<CopyTensorOp>(op);
|
||||
}
|
||||
|
||||
void optimize(FuncOp func, TypeAnalyzer &analyzer) {
|
||||
|
@ -370,15 +336,10 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
|
|||
// types).
|
||||
std::function<Value(Location, Type, Value)> createStaticInfoCast;
|
||||
OpBuilder b(op->getBlock(), std::next(op->getIterator()));
|
||||
if (originalType.isa<TensorType>()) {
|
||||
if (originalType.isa<BaseTensorType>()) {
|
||||
createStaticInfoCast = [&](Location loc, Type newType,
|
||||
Value v) -> Value {
|
||||
return b.create<Numpy::TensorStaticInfoCastOp>(loc, newType, v);
|
||||
};
|
||||
} else if (originalType.isa<Numpy::NdArrayType>()) {
|
||||
createStaticInfoCast = [&](Location loc, Type newType,
|
||||
Value v) -> Value {
|
||||
return b.create<Numpy::StaticInfoCastOp>(loc, newType, v);
|
||||
return b.create<TensorStaticInfoCastOp>(loc, newType, v);
|
||||
};
|
||||
}
|
||||
if (createStaticInfoCast) {
|
||||
|
@ -392,7 +353,7 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
|
|||
// Always make sure that the new static information is reflected in the
|
||||
// IR, either by updating the type in place, or inserting a static info
|
||||
// cast.
|
||||
if (allowsTypeRefinement(op)) {
|
||||
if (allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(op)) {
|
||||
newTypedValue = v;
|
||||
v.setType(refinedType);
|
||||
} else {
|
||||
|
@ -402,7 +363,8 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
|
|||
Value oldTypedValue;
|
||||
for (OpOperand *use : originalUses) {
|
||||
// If the use can be updated to the new type directly, do it!
|
||||
if (allowsTypeRefinement(use->getOwner())) {
|
||||
if (allowsTypeRefinementOrWillBeOtherwiseSafelyRefined(
|
||||
use->getOwner())) {
|
||||
use->set(newTypedValue);
|
||||
continue;
|
||||
}
|
||||
|
@ -421,9 +383,6 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
|
|||
|
||||
namespace {
|
||||
class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<Numpy::NumpyDialect>();
|
||||
}
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
TypeAnalyzer analyzer(&getContext());
|
||||
|
|
|
@ -1,62 +1,72 @@
|
|||
// RUN: npcomp-opt <%s -convert-torch-to-linalg | FileCheck %s
|
||||
// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.mm$basic(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<?x?xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x2xf32> {
|
||||
// CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
|
||||
// CHECK: %[[LHS:.*]] = torch.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[RHS:.*]] = torch.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[CF0:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[LHS_DIM_0:.*]] = memref.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[LHS_DIM_1:.*]] = memref.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[RHS_DIM_0:.*]] = memref.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[RHS_DIM_1:.*]] = memref.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[EQ:.*]] = cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index
|
||||
// CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm"
|
||||
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[LHS_DIM_0]], %[[RHS_DIM_1]]] : tensor<?x?xf32>
|
||||
// CHECK: %[[CF0:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[CF0]]) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
|
||||
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
|
||||
// CHECK: return %[[CASTED]] : tensor<?x2xf32>
|
||||
func @torch.aten.mm$basic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x2xf32> {
|
||||
%0 = torch.aten.mm %arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x2xf32>
|
||||
return %0 : tensor<?x2xf32>
|
||||
}
|
||||
|
||||
// If the operands are missing dtype, we cannot lower it.
|
||||
// CHECK-LABEL: func @torch.aten.mm$no_convert$missing_dtype
|
||||
func @torch.aten.mm$no_convert$missing_dtype(%arg0: tensor<*x!numpy.any_dtype>, %arg1: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
||||
// CHECK-NEXT: torch.aten.mm
|
||||
%0 = torch.aten.mm %arg0, %arg1 : tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype> -> tensor<*x!numpy.any_dtype>
|
||||
return %0 : tensor<*x!numpy.any_dtype>
|
||||
}
|
||||
|
||||
// Correctly handle the case that operands are statically the wrong rank
|
||||
// (rank 1 vs rank 2 expected for matmul.)
|
||||
// CHECK-LABEL: func @torch.aten.mm$no_convert$wrong_rank
|
||||
func @torch.aten.mm$no_convert$wrong_rank(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK-NEXT: torch.aten.mm
|
||||
%0 = torch.aten.mm %arg0, %arg1 : tensor<?xf32>, tensor<?xf32> -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// If the result is missing dtype, we cannot lower it.
|
||||
// CHECK-LABEL: func @torch.aten.mm$no_convert$result_missing_dtype
|
||||
func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
// CHECK-NEXT: torch.aten.mm
|
||||
%0 = torch.aten.mm %arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<*x!numpy.any_dtype>
|
||||
return %0 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK: %[[RESULT_VTENSOR:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>
|
||||
// CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32>
|
||||
func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
|
||||
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
|
||||
return %0 : !torch.vtensor<[?,2],f32>
|
||||
}
|
||||
|
||||
// Unary op example.
|
||||
// CHECK-LABEL: func @torch.aten.tanh(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK-SAME: %[[ARG_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[ARG:.*]] = torch.to_builtin_tensor %[[ARG_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[ARG]] : tensor<?x?xf32>) outs(%[[ARG]] : tensor<?x?xf32>) {
|
||||
// CHECK: ^bb0(%[[BBARG:.*]]: f32, %{{.*}}: f32):
|
||||
// CHECK: %[[YIELDED:.*]] = math.tanh %[[BBARG]] : f32
|
||||
// CHECK: linalg.yield %[[YIELDED]] : f32
|
||||
// CHECK: } -> tensor<?x?xf32>
|
||||
// CHECK: return %[[RESULT:.*]] : tensor<?x?xf32>
|
||||
func @torch.aten.tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%0 = torch.aten.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_VTENSOR:.*]] = torch.from_builtin_tensor %[[RESULT]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT_VTENSOR:.*]] : !torch.vtensor<[?,?],f32>
|
||||
func @torch.aten.tanh(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// If the operands are missing dtype, we cannot lower it.
|
||||
func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||
// expected-error@+1 {{failed to legalize}}
|
||||
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Correctly handle the case that operands are statically the wrong rank
|
||||
// (rank 1 vs rank 2 expected for matmul.)
|
||||
func @torch.aten.mm$no_convert$wrong_rank(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// expected-error@+1 {{failed to legalize}}
|
||||
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// If the result is missing dtype, we cannot lower it.
|
||||
func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
|
||||
// expected-error@+1 {{failed to legalize}}
|
||||
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
// RUN: npcomp-opt <%s -convert-torch-to-std | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @aten.dim(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!numpy.any_dtype>) -> i64 {
|
||||
// CHECK: %[[RANK_INDEX:.*]] = rank %[[ARG0]] : tensor<*x!numpy.any_dtype>
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> i64 {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG0]] : !torch.vtensor<*,f32> -> tensor<*xf32>
|
||||
// CHECK: %[[RANK_INDEX:.*]] = rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
|
||||
// CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK_INDEX]] : index to i64
|
||||
// CHECK: return %[[RANK_I64]] : i64
|
||||
func @aten.dim(%arg0: tensor<*x!numpy.any_dtype>) -> i64 {
|
||||
%0 = torch.aten.dim %arg0 : tensor<*x!numpy.any_dtype> -> i64
|
||||
func @aten.dim(%arg0: !torch.vtensor<*,f32>) -> i64 {
|
||||
%0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> i64
|
||||
return %0 : i64
|
||||
}
|
||||
|
||||
|
@ -31,3 +32,22 @@ func @torch.aten.gt.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
|
|||
%0 = torch.aten.gt.int %arg0, %arg1 : i64, i64 -> !basicpy.BoolType
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.tensor$value() -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32>
|
||||
func @torch.tensor$value() -> !torch.vtensor<[],f32> {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.tensor$nonval() -> !torch.tensor<[],f32> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[NONVAL:.*]] = torch.copy.tensor %[[VTENSOR]] : !torch.vtensor<[],f32> -> !torch.tensor<[],f32>
|
||||
// CHECK: return %[[NONVAL]] : !torch.tensor<[],f32>
|
||||
func @torch.tensor$nonval() -> !torch.tensor<[],f32> {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<f32>) : !torch.tensor<[],f32>
|
||||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
// RUN: npcomp-opt -split-input-file %s -numpy-array-to-tensor | FileCheck --dump-input=fail %s
|
||||
|
||||
// Basic case that can be resolved with local reasoning.
|
||||
// This pass will eventually need to learn about aliasing relationships.
|
||||
//
|
||||
// This is taken from a test case from an e2e spike, and isn't intended to be
|
||||
// particularly minimal or specifically test one thing, since the pass is
|
||||
// currently just a handful of canonicalization patterns that are already
|
||||
// tested elsewhere.
|
||||
|
||||
// CHECK-LABEL: func @local(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
// CHECK: %[[ERASED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tanh %[[ERASED]] : tensor<*x!numpy.any_dtype> -> tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[RET]] : tensor<*x!numpy.any_dtype>
|
||||
func @local(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
%0 = numpy.create_array_from_tensor %arg0 : (tensor<2x3x?xf32>) -> !numpy.ndarray<[2,3,?]:f32>
|
||||
%1 = numpy.static_info_cast %0 : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%2 = numpy.copy_to_tensor %1 : (!numpy.ndarray<*:!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
|
||||
%3 = torch.aten.tanh %2 : tensor<*x!numpy.any_dtype> -> tensor<*x!numpy.any_dtype>
|
||||
%4 = numpy.create_array_from_tensor %3 : (tensor<*x!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%5 = numpy.copy_to_tensor %4 : (!numpy.ndarray<*:!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
|
||||
return %5 : tensor<*x!numpy.any_dtype>
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
// RUN: npcomp-opt -split-input-file %s -verify-diagnostics -allow-unregistered-dialect -numpy-refine-public-return | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<?xf32>) {
|
||||
// CHECK: %[[CTRUE:.*]] = constant true
|
||||
// CHECK: %[[CAST:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[ARG]], %[[CTRUE]], %[[ARG]] : tensor<?xf32>, i1, tensor<?xf32>
|
||||
func @basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
|
||||
%ctrue = std.constant true
|
||||
%cast = numpy.tensor_static_info_cast %arg0 : tensor<?xf32> to tensor<*x!numpy.any_dtype>
|
||||
return %arg0, %ctrue, %cast : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
|
||||
}
|
||||
|
||||
// No conversion on private function.
|
||||
// CHECK-LABEL: func private @basic_private(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
|
||||
// CHECK: %[[CTRUE:.*]] = constant true
|
||||
// CHECK: %[[CASTED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[ARG]], %[[CTRUE]], %[[CASTED]] : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
|
||||
func private @basic_private(%arg0: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
|
||||
%ctrue = std.constant true
|
||||
%cast = numpy.tensor_static_info_cast %arg0 : tensor<?xf32> to tensor<*x!numpy.any_dtype>
|
||||
return %arg0, %ctrue, %cast : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Call to public function.
|
||||
// expected-error @+1 {{unimplemented}}
|
||||
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
return %arg0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Multiple returns.
|
||||
// expected-error @+1 {{unimplemented}}
|
||||
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%ctrue = constant true
|
||||
cond_br %ctrue, ^bb1, ^bb2
|
||||
^bb1:
|
||||
return %arg0 : tensor<*xf32>
|
||||
^bb2:
|
||||
return %arg0 : tensor<*xf32>
|
||||
}
|
|
@ -17,27 +17,25 @@
|
|||
// CHECK: torch.global_slot.init %[[INIT]] : f64
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: torch.global_slot @a : !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
// CHECK: %[[C:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
// CHECK: %[[A:.*]] = numpy.create_array_from_tensor %[[C]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: torch.global_slot.init %[[A]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK-LABEL: torch.global_slot @t : !torch.tensor {
|
||||
// CHECK: %[[T:.*]] = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
|
||||
// CHECK: }
|
||||
|
||||
torch.class_type @c {
|
||||
torch.attr "b" : !basicpy.BoolType
|
||||
torch.attr "i" : i64
|
||||
torch.attr "f" : f64
|
||||
torch.attr "a" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "t" : !torch.tensor
|
||||
}
|
||||
|
||||
%bool_true = basicpy.bool_constant true
|
||||
%i = basicpy.numeric_constant 3 : i64
|
||||
%f = basicpy.numeric_constant 4.250000e+01 : f64
|
||||
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
%a = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%t = torch.tensor(dense<1.0> : tensor<1xf32>) : !torch.tensor
|
||||
torch.nn_module {
|
||||
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||
torch.slot "i", %i : i64
|
||||
torch.slot "f", %f : f64
|
||||
torch.slot "a", %a : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "t", %t : !torch.tensor
|
||||
} : !torch.nn.Module<"c">
|
||||
|
|
|
@ -32,14 +32,13 @@ torch.class_type @parent {
|
|||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
torch.attr "a1" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "a2" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "t1" : !torch.tensor
|
||||
torch.attr "t2" : !torch.tensor
|
||||
}
|
||||
|
||||
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}}
|
||||
%a = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%t = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
torch.nn_module {
|
||||
torch.slot "a1", %a : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "a2", %a : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "t1", %t : !torch.tensor
|
||||
torch.slot "t2", %t : !torch.tensor
|
||||
} : !torch.nn.Module<"c">
|
||||
|
|
|
@ -1,29 +1,32 @@
|
|||
// RUN: npcomp-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
// CHECK: %[[RET:.*]] = numpy.static_info_cast %[[ARG]] : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
func @basic(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[ERASED]] : !torch.vtensor -> !torch.tensor
|
||||
// CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor
|
||||
func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
|
||||
return %arg0 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @no_type_bound(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
// CHECK: return %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
func @no_type_bound(%arg0: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: return %[[ARG]] : !torch.tensor
|
||||
func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
|
||||
return %arg0 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @call(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.static_info_cast %[[ARG]] : !numpy.ndarray<[2,3,?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: %[[SHAPED:.*]] = numpy.static_info_cast %[[SHAPE_ERASED]] : !numpy.ndarray<*:!numpy.any_dtype> to !numpy.ndarray<[2,3,?]:f32>
|
||||
// CHECK: %[[RES:.*]] = call @call(%[[SHAPED]]) : (!numpy.ndarray<[2,3,?]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
func @call(%arg0: !numpy.ndarray<*:!numpy.any_dtype> {torch.type_bound = !numpy.ndarray<[2,3,?]:f32>}) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
%0 = call @call(%arg0) : (!numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
return %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.tensor %[[ARG_ERASED]] : !torch.vtensor -> !torch.tensor
|
||||
// CHECK: %[[INFO_ADDED:.*]] = torch.tensor_static_info_cast %[[ARG_NONVAL]] : !torch.tensor to !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[CALL_ARG:.*]] = torch.copy.tensor %[[INFO_ADDED]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor
|
||||
// CHECK: return %[[ARG_NONVAL]] : !torch.tensor
|
||||
func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
|
||||
%0 = call @call(%arg0) : (!torch.tensor) -> !torch.tensor
|
||||
return %arg0 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @none_return() {
|
||||
|
|
|
@ -7,21 +7,71 @@ func @torch.aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !
|
|||
%0 = torch.aten.__is__ %arg0, %arg1 : !basicpy.ListType, !basicpy.NoneType -> !basicpy.BoolType
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
||||
// CHECK-LABEL: func @torch.aten.size(
|
||||
// CHECK: %[[CM1:.*]] = constant -1 : i64
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.size$canonicalize_to_list(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !basicpy.ListType {
|
||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
||||
// CHECK: %[[C3:.*]] = constant 3 : i64
|
||||
// CHECK: %[[RET:.*]] = basicpy.build_list %[[CM1]], %[[C3]] : (i64, i64) -> !basicpy.ListType
|
||||
// CHECK: return %[[RET]] : !basicpy.ListType
|
||||
func @torch.aten.size(%arg0: tensor<?x3xf32>) -> !basicpy.ListType {
|
||||
%0 = torch.aten.size %arg0 : tensor<?x3xf32> -> !basicpy.ListType
|
||||
// CHECK: %[[LIST:.*]] = basicpy.build_list %[[C2]], %[[C3]] : (i64, i64) -> !basicpy.ListType
|
||||
// CHECK: return %[[LIST]] : !basicpy.ListType
|
||||
func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !basicpy.ListType {
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !basicpy.ListType
|
||||
return %0 : !basicpy.ListType
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.len.t(
|
||||
// CHECK: %[[LENGTH:.*]] = constant 2 : i64
|
||||
// CHECK: return %[[LENGTH]] : i64
|
||||
func @torch.aten.len.t(%arg0: i64) -> i64 {
|
||||
%0 = basicpy.build_list %arg0, %arg0 : (i64, i64) -> !basicpy.ListType
|
||||
// One size unknown, so cannot canonicalize.
|
||||
// TODO: For unknown sizes, insert the equivalent of a "dim" op.
|
||||
// Then this will only require static rank.
|
||||
// CHECK-LABEL: func @torch.aten.size$unknown_size(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !basicpy.ListType {
|
||||
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !basicpy.ListType
|
||||
func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !basicpy.ListType {
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !basicpy.ListType
|
||||
return %0 : !basicpy.ListType
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.len.t$of_size(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> i64 {
|
||||
// CHECK: %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> i64
|
||||
// CHECK: return %[[DIM]] : i64
|
||||
func @torch.aten.len.t$of_size(%arg0: !torch.vtensor<*,f32>) -> i64 {
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !basicpy.ListType
|
||||
%1 = torch.aten.len.t %0 : !basicpy.ListType -> i64
|
||||
return %1 : i64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.dim$with_shape(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?,?],f32>) -> i64 {
|
||||
// CHECK: %[[DIM:.*]] = constant 3 : i64
|
||||
// CHECK: return %[[DIM]] : i64
|
||||
func @torch.aten.dim$with_shape(%arg0: !torch.vtensor<[?,?,?],f32>) -> i64 {
|
||||
%0 = torch.aten.dim %arg0 : !torch.vtensor<[?,?,?],f32> -> i64
|
||||
return %0 : i64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.len.t$of_build_list(
|
||||
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
|
||||
// CHECK: %[[LEN:.*]] = constant 4 : i64
|
||||
// CHECK: return %[[LEN]] : i64
|
||||
func @torch.aten.len.t$of_build_list(%arg0: i64) -> i64 {
|
||||
%0 = basicpy.build_list %arg0, %arg0, %arg0, %arg0 : (i64, i64, i64, i64) -> !basicpy.ListType
|
||||
%1 = torch.aten.len.t %0 : !basicpy.ListType -> i64
|
||||
return %1 : i64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.copy.tensor$value_copy_is_noop(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||
// CHECK: return %[[ARG]] : !torch.vtensor
|
||||
func @torch.copy.tensor$value_copy_is_noop(%arg0: !torch.vtensor) -> !torch.vtensor {
|
||||
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||
// CHECK: return %[[ARG]] : !torch.vtensor
|
||||
func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
|
||||
%0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
|
||||
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
// RUN: npcomp-opt %s -torch-finalizing-builtin-tensorize -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s
|
||||
|
||||
// This test is largely copied from `finalizing-bufferize` upstream, as it
|
||||
// covers the same scope.
|
||||
|
||||
// CHECK-LABEL: func @eliminate_materializations(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: return %[[ARG]] : tensor<f32>
|
||||
func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = torch.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
%1 = torch.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
|
||||
// expected-error @+1 {{failed to legalize operation 'test.source'}}
|
||||
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
|
||||
%1 = torch.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
|
||||
%0 = torch.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// expected-error @+1 {{failed to legalize operation 'test.sink'}}
|
||||
"test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()
|
||||
return
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
// RUN: npcomp-opt %s -torch-func-builtin-tensorize -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s
|
||||
|
||||
// This test is largely copied from `func-bufferize` upstream, as it covers
|
||||
// the same scope.
|
||||
|
||||
// CHECK-LABEL: func @identity(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: %[[TENSOR:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: return %[[MEMREF]] : tensor<f32>
|
||||
func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
return %arg0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @block_arguments(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: %[[T1:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[M1:.*]] = torch.to_builtin_tensor %[[T1]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: br ^bb1(%[[M1]] : tensor<f32>)
|
||||
// CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>):
|
||||
// CHECK: %[[T2:.*]] = torch.from_builtin_tensor %[[BBARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[M2:.*]] = torch.to_builtin_tensor %[[T2]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: return %[[M2]] : tensor<f32>
|
||||
func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
br ^bb1(%arg0: !torch.vtensor<[],f32>)
|
||||
^bb1(%bbarg: !torch.vtensor<[],f32>):
|
||||
return %bbarg : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func private @source() -> tensor<f32>
|
||||
// CHECK-LABEL: func @call_source() -> tensor<f32> {
|
||||
// CHECK: %[[RET:.*]] = call @source() : () -> tensor<f32>
|
||||
// CHECK: return %[[RET]] : tensor<f32>
|
||||
func private @source() -> !torch.vtensor<[],f32>
|
||||
func @call_source() -> !torch.vtensor<[],f32> {
|
||||
%0 = call @source() : () -> !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
// CHECK-LABEL: func @call_sink(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) {
|
||||
// CHECK: %[[TENSOR:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: call @sink(%[[MEMREF]]) : (tensor<f32>) -> ()
|
||||
// CHECK: return
|
||||
func private @sink(!torch.vtensor<[],f32>)
|
||||
func @call_sink(%arg0: !torch.vtensor<[],f32>) {
|
||||
call @sink(%arg0) : (!torch.vtensor<[],f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unconverted_op_in_body() -> tensor<f32> {
|
||||
// CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: return %[[MEMREF]] : tensor<f32>
|
||||
func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
|
||||
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Because this pass updates block arguments, it needs to also atomically
|
||||
// update all terminators and issue an error if that is not possible.
|
||||
func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
%0 = constant true
|
||||
cond_br %0, ^bb1(%arg0: !torch.vtensor<[],f32>), ^bb2(%arg0: !torch.vtensor<[],f32>)
|
||||
^bb1(%bbarg0: !torch.vtensor<[],f32>):
|
||||
// expected-error @+1 {{failed to legalize operation 'test.terminator'}}
|
||||
"test.terminator"() : () -> ()
|
||||
^bb2(%bbarg1: !torch.vtensor<[],f32>):
|
||||
return %bbarg1 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// There was a bug in func-bufferize pass which caused terminators without
|
||||
// ReturnLike and BranchOpInterface traits (e.g. scf.condition) to always
|
||||
// fail to legalize even if bufferization doesn't needed.
|
||||
// Check the pass succedeed.
|
||||
// CHECK: while
|
||||
// CHECK: scf.while
|
||||
// CHECK: scf.condition
|
||||
func @bwhile(%arg0: i64, %arg1: i64) -> i64 {
|
||||
%c2_i64 = constant 2 : i64
|
||||
%0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) {
|
||||
%1 = cmpi slt, %arg2, %arg1 : i64
|
||||
scf.condition(%1) %arg2, %arg2 : i64, i64
|
||||
} do {
|
||||
^bb0(%arg2: i64, %arg3: i64):
|
||||
%1 = muli %arg3, %c2_i64 : i64
|
||||
scf.yield %1 : i64
|
||||
}
|
||||
return %0#1 : i64
|
||||
}
|
|
@ -1,41 +1,37 @@
|
|||
// RUN: npcomp-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-NOT: @readonly
|
||||
torch.global_slot "private" @readonly : !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
%cst = constant dense<0.0> : tensor<1xf32>
|
||||
%0 = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.global_slot.init %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.global_slot "private" @readonly : !torch.tensor {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<1xf32>) : !torch.tensor
|
||||
torch.global_slot.init %0 : !torch.tensor
|
||||
}
|
||||
// CHECK-LABEL: torch.global_slot @public
|
||||
torch.global_slot @public : !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
%cst = constant dense<0.0> : tensor<2xf32>
|
||||
%0 = numpy.create_array_from_tensor %cst : (tensor<2xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.global_slot.init %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.global_slot @public : !torch.tensor {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<2xf32>) : !torch.tensor
|
||||
torch.global_slot.init %0 : !torch.tensor
|
||||
}
|
||||
// CHECK-LABEL: torch.global_slot "private" @mutated
|
||||
torch.global_slot "private" @mutated : !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
%cst = constant dense<0.0> : tensor<3xf32>
|
||||
%0 = numpy.create_array_from_tensor %cst : (tensor<3xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.global_slot.init %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.global_slot "private" @mutated : !torch.tensor {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<3xf32>) : !torch.tensor
|
||||
torch.global_slot.init %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @forward() -> (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) {
|
||||
func @forward() -> (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) {
|
||||
// CHECK-LABEL: func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
|
||||
func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
|
||||
// Inlined.
|
||||
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<1xf32>
|
||||
// CHECK: %[[ARRAY_CST:.*]] = numpy.create_array_from_tensor %[[CST]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%0 = torch.global_slot.get @readonly : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: %[[READONLY:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
%0 = torch.global_slot.get @readonly : !torch.tensor
|
||||
|
||||
// Not inlined: potentially mutated by externals.
|
||||
// CHECK: %[[PUBLIC:.*]] = torch.global_slot.get @public : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%1 = torch.global_slot.get @public : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: %[[PUBLIC:.*]] = torch.global_slot.get @public : !torch.tensor
|
||||
%1 = torch.global_slot.get @public : !torch.tensor
|
||||
|
||||
// Not inlined: potentially mutated internally.
|
||||
// CHECK: torch.global_slot.set @mutated = %[[ARRAY_CST]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: %[[MUTATED:.*]] = torch.global_slot.get @mutated : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.global_slot.set @mutated = %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%2 = torch.global_slot.get @mutated : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: torch.global_slot.set @mutated = %[[READONLY]] : !torch.tensor
|
||||
// CHECK: %[[MUTATED:.*]] = torch.global_slot.get @mutated : !torch.tensor
|
||||
torch.global_slot.set @mutated = %0 : !torch.tensor
|
||||
%2 = torch.global_slot.get @mutated : !torch.tensor
|
||||
|
||||
// CHECK: return %[[ARRAY_CST]], %[[PUBLIC]], %[[MUTATED]] : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
|
||||
return %0, %1, %2 : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: return %[[READONLY]], %[[PUBLIC]], %[[MUTATED]] : !torch.tensor, !torch.tensor, !torch.tensor
|
||||
return %0, %1, %2 : !torch.tensor, !torch.tensor, !torch.tensor
|
||||
}
|
||||
|
|
|
@ -100,8 +100,8 @@ torch.class_type @c {
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{'torch.type_bound' must be attached to an argument of !numpy.ndarray type}}
|
||||
func @f(%arg0: i32 {torch.type_bound = !numpy.ndarray<*:f32>})
|
||||
// expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
|
||||
func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -110,7 +110,7 @@ func @f(%arg0: i32 {torch.type_bound = 1})
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{'torch.type_bound' must be of !numpy.ndarray type}}
|
||||
// expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
|
||||
func @f(%arg0: i32 {torch.type_bound = i32})
|
||||
|
||||
// -----
|
||||
|
@ -120,3 +120,35 @@ func @derefine(%arg0: !torch.optional<tensor<f32>>) -> tensor<f32> {
|
|||
%0 = torch.derefine %arg0 : !torch.optional<tensor<f32>> to tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{invalid dtype 'tuple<>' for !torch.tensor type}}
|
||||
func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
|
||||
|
||||
// -----
|
||||
|
||||
func @torch.tensor() {
|
||||
// Incompatible shape.
|
||||
// expected-error@+1 {{incompatible}}
|
||||
%0 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @torch.tensor() {
|
||||
// Incompatible dtype.
|
||||
// expected-error@+1 {{incompatible}}
|
||||
%0 = torch.tensor(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @torch.tensor() {
|
||||
// Incompatible type.
|
||||
// expected-error@+1 {{incompatible}}
|
||||
%0 = torch.tensor(dense<42.0> : tensor<f32>) : i1
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
// RUN: npcomp-opt -split-input-file %s -torch-maximize-value-semantics | FileCheck %s
|
||||
|
||||
// Basic case that can be resolved with local reasoning.
|
||||
// This pass will eventually need to learn about aliasing relationships.
|
||||
//
|
||||
// This is taken from a test case from an e2e spike, and isn't intended to be
|
||||
// particularly minimal or specifically test one thing, since the pass is
|
||||
// currently just a handful of canonicalization patterns that are already
|
||||
// tested elsewhere.
|
||||
|
||||
// CHECK-LABEL: func @local(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[2,3,?],f32>
|
||||
func @local(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
|
||||
%1 = torch.aten.tanh %0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
%2 = torch.copy.tensor %1 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
%3 = torch.tensor_static_info_cast %2 : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
%4 = torch.copy.tensor %2 : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
return %4 : !torch.vtensor<[2,3,?],f32>
|
||||
}
|
|
@ -1,28 +1,64 @@
|
|||
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @torch.operator(
|
||||
func @torch.operator(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
// CHECK: torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%0 = torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
return %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor
|
||||
%0 = torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
func @torch.linear_params.create(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>) -> (!torch.LinearParams, !torch.LinearParams) {
|
||||
%with_bias = torch.linear_params.create %arg0, %arg1 : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%without_bias = torch.linear_params.create %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
func @torch.linear_params.create(%arg0: !torch.tensor, %arg1: !torch.tensor) -> (!torch.LinearParams, !torch.LinearParams) {
|
||||
%with_bias = torch.linear_params.create %arg0, %arg1 : !torch.tensor, !torch.tensor
|
||||
%without_bias = torch.linear_params.create %arg0 : !torch.tensor
|
||||
return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams
|
||||
}
|
||||
|
||||
func @derefine(%arg0: tensor<f32>) -> !torch.optional<tensor<f32>> {
|
||||
%0 = torch.derefine %arg0 : tensor<f32> to !torch.optional<tensor<f32>>
|
||||
return %0 : !torch.optional<tensor<f32>>
|
||||
// CHECK-LABEL: func @builtin_tensor_interop(
|
||||
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xsi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
|
||||
// CHECK: torch.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
|
||||
%0 = torch.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: torch.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
|
||||
%1 = torch.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
|
||||
// CHECK: torch.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
|
||||
%2 = torch.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
|
||||
// CHECK: torch.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
|
||||
%3 = torch.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: @tensor.default() -> !torch.tensor
|
||||
func private @tensor.default() -> !torch.tensor
|
||||
// CHECK: @tensor.default_explicit() -> !torch.tensor{{$}}
|
||||
func private @tensor.default_explicit() -> !torch.tensor<*,unk>
|
||||
// CHECK: @tensor.value_semantic() -> !torch.vtensor{{$}}
|
||||
func private @tensor.value_semantic() -> !torch.vtensor<*,unk>
|
||||
// CHECK: @tensor.dtype() -> !torch.tensor<*,si32>
|
||||
func private @tensor.dtype() -> !torch.tensor<*,si32>
|
||||
// CHECK: @tensor.ranked() -> !torch.tensor<[?,?,?],unk>
|
||||
func private @tensor.ranked() -> !torch.tensor<[?,?,?],unk>
|
||||
// CHECK: @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
|
||||
func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
|
||||
// CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
|
||||
func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
|
||||
|
||||
// CHECK-LABEL: func @torch.tensor() {
|
||||
func @torch.tensor() {
|
||||
// CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
|
||||
%0 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
|
||||
// CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
|
||||
%1 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
|
||||
return
|
||||
}
|
||||
|
||||
func @derefine(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
|
||||
%0 = torch.derefine %arg0 : !torch.tensor to !torch.optional<!torch.tensor>
|
||||
return %0 : !torch.optional<!torch.tensor>
|
||||
}
|
||||
|
||||
%bool_true = basicpy.bool_constant true
|
||||
%num3_i64 = basicpy.numeric_constant 3 : i64
|
||||
%num = basicpy.numeric_constant 4.250000e+01 : f64
|
||||
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
%array = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%tensor = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
%none = basicpy.singleton : !basicpy.NoneType
|
||||
func private @f(%arg0: !torch.nn.Module<"test">) {
|
||||
return
|
||||
|
@ -35,7 +71,7 @@ torch.class_type @test {
|
|||
torch.attr "b" : !basicpy.BoolType
|
||||
torch.attr "i" : i64
|
||||
torch.attr "f" : f64
|
||||
torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "t" : !torch.tensor
|
||||
torch.attr "submodule" : !torch.nn.Module<"empty">
|
||||
torch.attr "ob" : !torch.optional<!basicpy.BoolType>
|
||||
torch.method "method", @f
|
||||
|
@ -44,7 +80,7 @@ torch.nn_module {
|
|||
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||
torch.slot "i", %num3_i64 : i64
|
||||
torch.slot "f", %num : f64
|
||||
torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "t", %tensor : !torch.tensor
|
||||
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
|
||||
torch.slot "ob", %none : !basicpy.NoneType
|
||||
} : !torch.nn.Module<"test">
|
||||
|
|
|
@ -1,33 +1,33 @@
|
|||
// RUN: npcomp-opt -torch-reduce-op-variants %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @convert_to_immutable_tensors(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<[]:f32>) -> !numpy.ndarray<[]:f32> {
|
||||
// CHECK: %[[OPERAND_TENSOR:.*]] = numpy.copy_to_tensor %[[ARG]] : (!numpy.ndarray<[]:f32>) -> tensor<f32>
|
||||
// CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : tensor<f32> -> tensor<f32>
|
||||
// CHECK: %[[RET:.*]] = numpy.create_array_from_tensor %[[RESULT_TENSOR]] : (tensor<f32>) -> !numpy.ndarray<[]:f32>
|
||||
// CHECK: return %[[RET]] : !numpy.ndarray<[]:f32>
|
||||
func @convert_to_immutable_tensors(%arg0: !numpy.ndarray<[]:f32>) -> !numpy.ndarray<[]:f32> {
|
||||
%0 = torch.aten.tanh %arg0 : !numpy.ndarray<[]:f32> -> !numpy.ndarray<[]:f32>
|
||||
return %0 : !numpy.ndarray<[]:f32>
|
||||
// CHECK-LABEL: func @convert_to_value_semantic_tensors(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.tensor %[[ARG]] : !torch.tensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RET:.*]] = torch.copy.tensor %[[RESULT_TENSOR]] : !torch.vtensor<[],f32> -> !torch.tensor<[],f32>
|
||||
// CHECK: return %[[RET]] : !torch.tensor<[],f32>
|
||||
func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
%0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
|
||||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @reduce_trailing_underscore_inplace_variant(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !numpy.ndarray<[2,2]:f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
||||
// CHECK: %[[VAL_2:.*]] = constant 1 : i64
|
||||
// CHECK: %[[TENSOR0:.*]] = numpy.copy_to_tensor %[[ARG0]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
|
||||
// CHECK: %[[TENSOR1:.*]] = numpy.copy_to_tensor %[[ARG1]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
|
||||
// CHECK: %[[TENSOR_RESULT:.*]] = torch.aten.add.Tensor %[[TENSOR0]], %[[TENSOR1]], %[[VAL_2]] : tensor<2x2xf32>, tensor<2x2xf32>, i64 -> tensor<2x2xf32>
|
||||
// Note: This somewhat redundant tensor->array->tensor conversion
|
||||
// CHECK: %[[TENSOR0:.*]] = torch.copy.tensor %[[ARG0]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR1:.*]] = torch.copy.tensor %[[ARG1]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR_RESULT:.*]] = torch.aten.add.Tensor %[[TENSOR0]], %[[TENSOR1]], %[[VAL_2]] : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, i64 -> !torch.vtensor<[2,2],f32>
|
||||
// Note: This somewhat redundant conversion back and forth
|
||||
// (which is cleaned up by canonicalization) is an artifact of two patterns
|
||||
// being applied in sequence.
|
||||
// CHECK: %[[ARRAY_RESULT:.*]] = numpy.create_array_from_tensor %[[TENSOR_RESULT]] : (tensor<2x2xf32>) -> !numpy.ndarray<[2,2]:f32>
|
||||
// CHECK: %[[TENSOR_AGAIN:.*]] = numpy.copy_to_tensor %[[ARRAY_RESULT]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
|
||||
// CHECK: numpy.overwrite_array %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : tensor<2x2xf32>, !numpy.ndarray<[2,2]:f32>
|
||||
// CHECK: return %[[ARG0]], %[[ARG0]] : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
|
||||
func @reduce_trailing_underscore_inplace_variant(%arg0: !numpy.ndarray<[2,2]:f32>, %arg1: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
|
||||
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.tensor %[[TENSOR_RESULT]] : !torch.vtensor<[2,2],f32> -> !torch.tensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.tensor %[[ARRAY_RESULT]] : !torch.tensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
|
||||
// CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
||||
%c1 = constant 1 : i64
|
||||
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>, i64 -> !numpy.ndarray<[2,2]:f32>
|
||||
return %0, %arg0 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
|
||||
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, i64 -> !torch.tensor<[2,2],f32>
|
||||
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
}
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
||||
// CHECK: %[[COPIED_NONVAL:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[COPIED_VALUE:.*]] = torch.copy.tensor %[[COPIED_NONVAL]] : !torch.tensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: return %[[COPIED_VALUE]] : !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: }
|
||||
func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
%1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
return %2 : !torch.tensor
|
||||
}
|
||||
|
||||
// No conversion on private function.
|
||||
// CHECK-LABEL: func private @basic_private(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[COPIED:.*]] = torch.copy.tensor %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[CASTED]] : !torch.tensor
|
||||
func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
%1 = torch.copy.tensor %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
return %2 : !torch.tensor
|
||||
}
|
||||
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Call to public function.
|
||||
// expected-error @+1 {{unimplemented}}
|
||||
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
return %arg0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Multiple returns.
|
||||
// expected-error @+1 {{unimplemented}}
|
||||
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%ctrue = constant true
|
||||
cond_br %ctrue, ^bb1, ^bb2
|
||||
^bb1:
|
||||
return %arg0 : tensor<*xf32>
|
||||
^bb2:
|
||||
return %arg0 : tensor<*xf32>
|
||||
}
|
|
@ -1,88 +1,102 @@
|
|||
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @f(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
// CHECK: %[[SHAPED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<2x3x?xf32> to tensor<2x3x?xf32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[SHAPED]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
|
||||
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
%0 = numpy.tensor_static_info_cast %arg0 : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
return %0 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[SHAPED]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
|
||||
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
// CHECK: %[[SHAPED:.*]] = torch.aten.tanh %[[ARG]] : tensor<2x3x?xf32> -> tensor<2x3x?xf32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[SHAPED]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
|
||||
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
%1 = torch.aten.tanh %arg0 : tensor<2x3x?xf32> -> tensor<*x!numpy.any_dtype>
|
||||
return %1 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.tensor %[[CASTED]] : !torch.vtensor<[2,3,?],f32> -> !torch.tensor<[2,3,?],f32>
|
||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[NONVAL_TENSOR]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[ERASED]] : !torch.tensor
|
||||
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
%1 = torch.copy.tensor %0 : !torch.vtensor -> !torch.tensor
|
||||
return %1 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<2x?xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[MM]] : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
|
||||
func @f(%arg0: tensor<2x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
%1 = torch.aten.mm %arg0, %arg1 : tensor<2x?xf32>, tensor<?x?xf32> -> tensor<*x!numpy.any_dtype>
|
||||
return %1 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[SHAPED:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[SHAPED]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
|
||||
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x3xf32>,
|
||||
// CHECK-SAME: %[[WEIGHT:.*]]: tensor<5x3xf32>,
|
||||
// CHECK-SAME: %[[BIAS:.*]]: tensor<5xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[INPUT]], %[[WEIGHT]], %[[BIAS]] : tensor<?x3xf32>, tensor<5x3xf32>, tensor<5xf32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[LINEAR]] : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
|
||||
func @f(%arg0: tensor<?x3xf32>, %arg1: tensor<5x3xf32>, %arg2: tensor<5xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
%1 = torch.aten.linear %arg0, %arg1, %arg2 : tensor<?x3xf32>, tensor<5x3xf32>, tensor<5xf32> -> tensor<*x!numpy.any_dtype>
|
||||
return %1 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
|
||||
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
|
||||
func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,3],f32>,
|
||||
// CHECK-SAME: %[[WEIGHT:.*]]: !torch.vtensor<[5,3],f32>,
|
||||
// CHECK-SAME: %[[BIAS:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[INPUT]], %[[WEIGHT]], %[[BIAS]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
|
||||
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
|
||||
func @f(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
// CHECK: %[[CONV2D:.*]] = torch.aten.conv2d{{.*}} -> tensor<?x?x?x?x!numpy.any_dtype>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[CONV2D]] : tensor<?x?x?x?x!numpy.any_dtype> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
|
||||
func @f(%arg0:tensor<*x!numpy.any_dtype>, %arg1:tensor<*x!numpy.any_dtype>, %arg2:tensor<*x!numpy.any_dtype>) ->tensor<*x!numpy.any_dtype> {
|
||||
// CHECK: %[[CONV2D:.*]] = torch.aten.conv2d{{.*}} -> !torch.vtensor<[?,?,?,?],unk>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[CONV2D]] : !torch.vtensor<[?,?,?,?],unk> to !torch.vtensor
|
||||
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
|
||||
func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!torch.vtensor {
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->tensor<*x!numpy.any_dtype>
|
||||
return %3 :tensor<*x!numpy.any_dtype>
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->!torch.vtensor
|
||||
return %3 :!torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @g
|
||||
// CHECK: %[[CONV2D:.*]] = torch.aten.conv2d{{.*}} -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[CONV2D]] : tensor<?x?x?x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
|
||||
func @g(%arg0:tensor<*xf32>, %arg1:tensor<*xf32>, %arg2:tensor<*xf32>) ->tensor<*x!numpy.any_dtype> {
|
||||
// CHECK: %[[CONV2D:.*]] = torch.aten.conv2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[CONV2D]] : !torch.vtensor<[?,?,?,?],f32> to !torch.vtensor
|
||||
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
|
||||
func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:!torch.vtensor<*,f32>) ->!torch.vtensor {
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->tensor<*x!numpy.any_dtype>
|
||||
return %3 :tensor<*x!numpy.any_dtype>
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->!torch.vtensor
|
||||
return %3 :!torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
|
||||
%c1_i64 = constant 1 : i64
|
||||
%c3_i64 = constant 3 : i64
|
||||
%c2_i64 = constant 2 : i64
|
||||
|
@ -91,78 +105,78 @@ func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
|||
%22 = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%23 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%24 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
// CHECK: torch.aten.max_pool2d{{.*}} -> tensor<?x?x?x?xf32>
|
||||
%27 = torch.aten.max_pool2d %arg0, %21, %22, %23, %24, %bool_false : tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType -> tensor<*x!numpy.any_dtype>
|
||||
return %27 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
|
||||
%27 = torch.aten.max_pool2d %arg0, %21, %22, %23, %24, %bool_false : !torch.vtensor<[?,?,?,?],f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType -> !torch.vtensor
|
||||
return %27 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<?x?x?x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
// CHECK: torch.aten.adaptive_avg_pool2d{{.*}} -> tensor<?x?x?x?xf32>
|
||||
%1 = torch.aten.adaptive_avg_pool2d %arg0, %0 : tensor<?x?x?x?xf32>, !basicpy.ListType -> tensor<*x!numpy.any_dtype>
|
||||
return %1 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK: torch.aten.adaptive_avg_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
|
||||
%1 = torch.aten.adaptive_avg_pool2d %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !basicpy.ListType -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Also test cast insertion for array types.
|
||||
// CHECK-LABEL: func @flatten_all(
|
||||
// CHECK: %[[FLATTENED:.*]] = torch.aten.flatten.using_ints{{.*}}-> !numpy.ndarray<[?]:f32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.static_info_cast %[[FLATTENED]] : !numpy.ndarray<[?]:f32> to !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: %[[FLATTENED:.*]] = torch.aten.flatten.using_ints{{.*}}-> !torch.tensor<[?],f32>
|
||||
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[FLATTENED]] : !torch.tensor<[?],f32> to !torch.tensor
|
||||
// CHECK: return %[[SHAPE_ERASED]]
|
||||
func @flatten_all(%arg0: !numpy.ndarray<[3,2,?,5]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
func @flatten_all(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
|
||||
%end = constant -1 : i64
|
||||
%start = constant 0 : i64
|
||||
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !numpy.ndarray<[3,2,?,5]:f32>, i64, i64 -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
return %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !torch.tensor<[3,2,?,5],f32>, i64, i64 -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @flatten_some(
|
||||
// CHECK: torch.aten.flatten.using_ints{{.*}}-> !numpy.ndarray<[3,?,5]:f32>
|
||||
func @flatten_some(%arg0: !numpy.ndarray<[3,2,?,5]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
// CHECK: torch.aten.flatten.using_ints{{.*}}-> !torch.tensor<[3,?,5],f32>
|
||||
func @flatten_some(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
|
||||
%end = constant -2 : i64
|
||||
%start = constant 1 : i64
|
||||
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !numpy.ndarray<[3,2,?,5]:f32>, i64, i64 -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
return %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !torch.tensor<[3,2,?,5],f32>, i64, i64 -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @flatten_rank0(
|
||||
// CHECK: torch.aten.flatten.using_ints{{.*}}-> !numpy.ndarray<[?]:f32>
|
||||
func @flatten_rank0(%arg0: !numpy.ndarray<[]:f32>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
// CHECK: torch.aten.flatten.using_ints{{.*}}-> !torch.tensor<[?],f32>
|
||||
func @flatten_rank0(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
|
||||
%end = constant -1 : i64
|
||||
%start = constant 0 : i64
|
||||
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !numpy.ndarray<[]:f32>, i64, i64 -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
return %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !torch.tensor<[],f32>, i64, i64 -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<4x6x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: tensor<?x3xf32>) {
|
||||
func @f(%arg0: !torch.vtensor<[4,6,3],f32>, %arg1: !torch.vtensor<[1,1,3],f32>, %arg2: !torch.vtensor<[?,3],f32>) {
|
||||
%c1_i64 = constant 1 : i64
|
||||
// CHECK: torch.aten.add{{.*}} -> tensor<?x?x?xf32>
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %c1_i64 : tensor<4x6x3xf32>, tensor<1x1x3xf32>, i64 -> tensor<*x!numpy.any_dtype>
|
||||
// CHECK: torch.aten.add{{.*}} -> tensor<?x?x?xf32>
|
||||
%1 = torch.aten.add.Tensor %arg0, %arg2, %c1_i64 : tensor<4x6x3xf32>, tensor<?x3xf32>, i64 -> tensor<*x!numpy.any_dtype>
|
||||
// CHECK: torch.aten.add{{.*}} -> !torch.vtensor<[?,?,?],f32>
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %c1_i64 : !torch.vtensor<[4,6,3],f32>, !torch.vtensor<[1,1,3],f32>, i64 -> !torch.vtensor
|
||||
// CHECK: torch.aten.add{{.*}} -> !torch.vtensor<[?,?,?],f32>
|
||||
%1 = torch.aten.add.Tensor %arg0, %arg2, %c1_i64 : !torch.vtensor<[4,6,3],f32>, !torch.vtensor<[?,3],f32>, i64 -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
||||
// Check propagation through multiple ops.
|
||||
// CHECK: torch.aten.tanh %{{.*}} : tensor<2x3x?xf32> -> tensor<2x3x?xf32>
|
||||
// CHECK: torch.aten.tanh %{{.*}} : tensor<2x3x?xf32> -> tensor<2x3x?xf32>
|
||||
// CHECK: torch.aten.tanh %{{.*}} : tensor<2x3x?xf32> -> tensor<2x3x?xf32>
|
||||
%1 = torch.aten.tanh %arg0 : tensor<2x3x?xf32> -> tensor<*x!numpy.any_dtype>
|
||||
%2 = torch.aten.tanh %1 : tensor<*x!numpy.any_dtype> -> tensor<*x!numpy.any_dtype>
|
||||
%3 = torch.aten.tanh %2 : tensor<*x!numpy.any_dtype> -> tensor<*x!numpy.any_dtype>
|
||||
return %3 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor
|
||||
%2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
|
||||
%3 = torch.aten.tanh %2 : !torch.vtensor -> !torch.vtensor
|
||||
return %3 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -170,44 +184,44 @@ func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
|||
// Check rewriting logic in case of mixes of users that do/don't allow type
|
||||
// refinement.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<2x3x?xf32>) -> (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>) {
|
||||
// CHECK: %[[REFINED_TYPE:.*]] = torch.aten.tanh %{{.*}} : tensor<2x3x?xf32> -> tensor<2x3x?xf32>
|
||||
%1 = torch.aten.tanh %arg0 : tensor<2x3x?xf32> -> tensor<*x!numpy.any_dtype>
|
||||
// CHECK: %[[ORIGINAL_TYPE:.*]] = numpy.tensor_static_info_cast %[[REFINED_TYPE]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: torch.aten.tanh %[[REFINED_TYPE]] : tensor<2x3x?xf32> -> tensor<2x3x?xf32>
|
||||
%3 = torch.aten.tanh %1 : tensor<*x!numpy.any_dtype> -> tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[ORIGINAL_TYPE]], %[[ORIGINAL_TYPE]] : tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>
|
||||
return %1, %1 : tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>
|
||||
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> (!torch.vtensor, !torch.vtensor) {
|
||||
// CHECK: %[[REFINED_TYPE:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor
|
||||
// CHECK: %[[ORIGINAL_TYPE:.*]] = torch.tensor_static_info_cast %[[REFINED_TYPE]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
// CHECK: torch.aten.tanh %[[REFINED_TYPE]] : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
|
||||
%3 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
|
||||
// CHECK: return %[[ORIGINAL_TYPE]], %[[ORIGINAL_TYPE]] : !torch.vtensor, !torch.vtensor
|
||||
return %1, %1 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
// CHECK: %[[ATEN:.*]] = torch.aten.tanh %{{.*}} : tensor<*x!numpy.any_dtype> -> tensor<2x3x?xf32>
|
||||
// CHECK: %[[CAST:.*]] = numpy.tensor_static_info_cast %[[ATEN]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
// CHECK: return %[[CAST]] : tensor<*x!numpy.any_dtype>
|
||||
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
|
||||
%cast = numpy.tensor_static_info_cast %arg0 : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
br ^bb1(%cast: tensor<*x!numpy.any_dtype>)
|
||||
^bb1(%arg1: tensor<*x!numpy.any_dtype>):
|
||||
%1 = torch.aten.tanh %arg1 : tensor<*x!numpy.any_dtype> -> tensor<*x!numpy.any_dtype>
|
||||
return %1 : tensor<*x!numpy.any_dtype>
|
||||
// CHECK: %[[ATEN:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<[2,3,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
br ^bb1(%cast: !torch.vtensor)
|
||||
^bb1(%arg1: !torch.vtensor):
|
||||
%1 = torch.aten.tanh %arg1 : !torch.vtensor -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
// CHECK: func private @callee
|
||||
// CHECK-NEXT: torch.aten.tanh %{{.*}} : tensor<*x!numpy.any_dtype> -> tensor<2x3x?xf32>
|
||||
// CHECK-NEXT: torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<[2,3,?],f32>
|
||||
func @f() {
|
||||
module {
|
||||
func private @callee(%arg0: tensor<*x!numpy.any_dtype>) {
|
||||
%1 = torch.aten.tanh %arg0 : tensor<*x!numpy.any_dtype> -> tensor<*x!numpy.any_dtype>
|
||||
func private @callee(%arg0: !torch.vtensor) {
|
||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
func @caller(%arg0: tensor<2x3x?xf32>) {
|
||||
%cast = numpy.tensor_static_info_cast %arg0 : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
|
||||
call @callee(%cast) : (tensor<*x!numpy.any_dtype>) -> ()
|
||||
func @caller(%arg0: !torch.vtensor<[2,3,?],f32>) {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
call @callee(%cast) : (!torch.vtensor) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue