mirror of https://github.com/llvm/torch-mlir
Add `aten.mm` to linalg lowering.

This is our first op with error semantics, and it stresses the system. There are a few design notes of special interest:
- RefineTypes.cpp's note about shape inference in the presence of code that dynamically produces an error, even when the error is statically provable.
- ATenToLinalg.cpp's notes about future automation of the ATen->linalg path.
- The notes in Passes.td about using low-tech `std.assert` ops instead of `shape.assuming`.

Note: Doesn't work on IREE yet due to the `std.assert` op (needs to be lowered to `vm.fail` on the IREE side).
parent 28a0f02746
commit f5dfa02523
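As a sketch of the resulting IR (illustrative SSA names; dynamic 2D f32 operands as in the `aten-to-linalg` FileCheck test at the bottom of this diff, which pins the exact output), the lowering guards the contracting dimension with a flat `std.assert` before emitting the linalg ops:

  %lhs_dim0 = memref.dim %lhs, %c0 : tensor<?x?xf32>
  %lhs_dim1 = memref.dim %lhs, %c1 : tensor<?x?xf32>
  %rhs_dim0 = memref.dim %rhs, %c0 : tensor<?x?xf32>
  %rhs_dim1 = memref.dim %rhs, %c1 : tensor<?x?xf32>
  %eq = cmpi eq, %lhs_dim1, %rhs_dim0 : index
  assert %eq, "mismatching contracting dimension for aten.mm"
  %init = linalg.init_tensor [%lhs_dim0, %rhs_dim1] : tensor<?x?xf32>
  %fill = linalg.fill(%init, %cf0) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
  %mm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>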
@ -0,0 +1,51 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import typing

import torch
import torch_mlir

import npcomp
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
from npcomp.compiler.utils import logging

import test_utils

logging.enable()

# RUN: %PYTHON %s | npcomp-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)

test_module = TestModule()
class_annotator = torch_mlir.ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)
torch.jit.save(recursivescriptmodule, '/tmp/foo.pt')

class_annotator.exportNone(recursivescriptmodule._c._type())
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
class_annotator.annotateShapesAndDtypes(recursivescriptmodule._c._type(), ['forward'], [
    None,
    ([-1, -1], torch.float32),
    ([-1, -1], torch.float32),
])
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print()

backend = refjit.CompilerBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

torch.manual_seed(0)
lhs = torch.rand(2, 3)
rhs = torch.rand(3, 4)
test_utils.compare_outputs(test_module.forward, jit_module.forward, lhs, rhs)
@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#define NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<FuncOp>> createConvertATenToLinalgPass();
}
} // namespace mlir

#endif // NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
@ -20,6 +20,85 @@ def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> {
  let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()";
}

def ConvertATenToLinalg : Pass<"convert-aten-to-linalg", "FuncOp"> {
  let summary = "Convert recognized ATen to Linalg ops";
  let description = [{
    Convert ATen ops to linalg ops.

    This pass's main responsibility is to bridge the world between ops
    that safely terminate the program in case of operand shape mismatches
    (ATen) and ops where such mismatches are undefined behavior (linalg).

    To model the termination of the program for implementing error guards,
    we use the `std.assert` op.
    This is a design decision at variance with other passes in npcomp,
    such as `convert-tcf-to-std` and `convert-tcf-to-linalg`, which use the
    `shape` dialect's witness system (`shape.cstr_*` family of ops feeding
    into `shape.assuming` regions). Those passes will be subsumed by this
    one. The reasons for this change are heuristic, but boil down to:

    1. The modeling of `shape.assuming` is odd, as it uses a region, which is
       not a good fit for modeling error guards. Regions mark a "start" and
       an "end" (which is their nesting property), but modeling assertions in
       the program doesn't fit into that. For assertions, only the "start"
       matters (once tested, a predicate remains true "forever" -- it doesn't
       end at the "yield" of the region).
       Thus, having regions places arbitrary "end"s that just add IR
       structure with no semantic value for modeling this problem! (And to
       make things worse, the "end"s, which we don't need, are what require
       "yielding" values, which interrupts use-def chains.) Consider the
       different structural properties of regions:
       a. IsolatedFromAbove region:
          - "start" interrupts use-def chains,
          - "end" interrupts use-def chains,
          - structurally protects from intra-block upward and downward
            code motion.
       b. Capturing region (like `shape.assuming`):
          - "start" does not interrupt use-def chains,
          - "end" interrupts use-def chains,
          - structurally protects from intra-block upward and downward
            code motion.
       c. What we "ideally" want:
          - "start" interrupts use-def chains (can be pruned though),
          - no "end" IR structure!
          - structurally protects from intra-block upward code motion
            (but not downward code motion!).
          - Observation: We probably can't get all of this, but overall this
            problem is much better suited for a "MemorySSA"-like abstraction,
            call it "EffectSSA", constructed on-demand based on MLIR's effect
            modeling system (rather than `shape.assuming`, which only covers
            the effects the IR creator encoded -- with
            witnesses/`shape.assuming` it is easy to forget to handle effects
            other than those encoded in the witness structure).
    2. The presence of `shape.assuming` regions tends to create highly nested
       IR structures, which don't interoperate well with any other IR
       structures, and creates very bulky IR (and IR creation code). In
       general, if we are going to do anything with anything (e.g.
       canonicalize), we end up needing to either:
       a. Flatten the `shape.assuming` IR (defeating the purpose of having
          it).
       b. Do some sort of `shape.assuming` "region merging".
       c. Have special patterns that handle a subset of special cases
          (looking through "yields" and such) and don't generalize.
    3. Witnesses tend to encourage non-scalable peephole transformations,
       which tend to make analyses/transformations non-robust to the presence
       of control flow and side-effecting ops (easy to forget to handle side
       effects other than those modeled by the witness system).
    4. All this code operates on ranked tensors, for which using individual
       SSA values for sizes (rather than a "shape type") seems to work really
       well at this level of abstraction, based on prior experience in IREE.
       (Unranked code tends to benefit from having a discrete "shape type" to
       model shapes.)

    We will see if we end up needing something like `shape.assuming`, but for
    now, it seems likely we can do something simpler and just bypass it. The
    design of having an EffectSSA that is constructed on-demand seems very
    compelling for modeling effects more broadly.
  }];
  let constructor = "mlir::NPCOMP::createConvertATenToLinalgPass()";
}
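To make the contrast concrete, here is a schematic sketch of the two guard styles for the contracting-dimension check (illustrative SSA names; the witness form assumes the upstream `shape` dialect's `shape.cstr_require`/`shape.assuming`/`shape.assuming_yield` ops of this era, with details elided via `...`):

  // Witness style (as in `convert-tcf-to-*`): the guarded computation is
  // nested in a region, and its results must be yielded back out through
  // the region's "end".
  %witness = shape.cstr_require %eq, "mismatching contracting dimension"
  %result = shape.assuming %witness -> (tensor<?x?xf32>) {
    %mm = linalg.matmul ...
    shape.assuming_yield %mm : tensor<?x?xf32>
  }

  // Assert style (this pass): a "start"-only guard with no region "end";
  // use-def chains flow through undisturbed.
  assert %eq, "mismatching contracting dimension"
  %result = linalg.matmul ...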

//===----------------------------------------------------------------------===//
// Basicpy conversions
//===----------------------------------------------------------------------===//
@ -642,4 +642,3 @@ def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<Torch
  let results = (outs
  );
}
@ -0,0 +1,142 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" // TODO: For `memref.dim`.
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"

using namespace mlir;
using namespace mlir::NPCOMP;

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#aten ops), which is in the 100s.
//
// Most of these patterns consist of:
// 1. Checking that the operand/result types and other static properties are
//    good enough to create a valid linalg op (such as operands being of
//    ranks/dtypes acceptable to the linalg op).
// 2. Creating dynamic error guards, usually checking a predicate on the
//    compatibility of operand shapes.
// 3. Creating init tensors for the computation op. Usually this involves
//    reifying IR for a shape transfer function based on the operand shapes.
// 4. Creating a named linalg op to replace the original op.
//
// TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such
// that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo".

static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
                                                 PatternRewriter &rewriter) {
  // For now, use a small allowlist of types we don't reject.
  // The main culprit in practice is that !numpy.any_dtype might be present
  // if shape/dtype inference wasn't good enough.
  auto isValidLinalgType = [](Type type) {
    if (auto rankedTensor = type.dyn_cast<RankedTensorType>()) {
      if (BaseMemRefType::isValidElementType(rankedTensor.getElementType()))
        return true;
    }
    if (type.isa<FloatType, IntegerType, IndexType>())
      return true;
    return false;
  };
  bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
               llvm::all_of(op->getResultTypes(), isValidLinalgType);
  if (!valid)
    return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
  return success();
}

LogicalResult convertMmOp(aten::MmOp op, PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  Value lhs = op.getOperand(0);
  Value rhs = op.getOperand(1);

  // A user can write an erroneous program where `aten.mm` is in fact called
  // with operands of invalid rank or dtype. We cannot convert to linalg in
  // this case or we will get a verifier error, which corresponds to breaking
  // of *internal* compiler invariants, and for a user manifests as a compiler
  // crash in the worst case (such as if we try to canonicalize/fold/print the
  // invalid op before the verifier gets to see it -- also, release builds of
  // a mature compiler usually have the verifier turned off for compile time
  // reasons).
  //
  // The compiler cannot crash even if the user wrote an erroneous program!
  if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
    return failure();
  if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
      rhs.getType().cast<RankedTensorType>().getRank() != 2) {
    return rewriter.notifyMatchFailure(
        op, "expected both operands to aten.mm to be rank 2");
  }

  Value lhsDim0 = rewriter.create<memref::DimOp>(loc, lhs, 0);
  Value lhsDim1 = rewriter.create<memref::DimOp>(loc, lhs, 1);
  Value rhsDim0 = rewriter.create<memref::DimOp>(loc, rhs, 0);
  Value rhsDim1 = rewriter.create<memref::DimOp>(loc, rhs, 1);
  Value contractingDimEqual =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsDim1, rhsDim0);
  rewriter.create<AssertOp>(
      loc, contractingDimEqual,
      rewriter.getStringAttr("mismatching contracting dimension for aten.mm"));

  Type elementType = op.getType().cast<TensorType>().getElementType();
  Value initTensor = rewriter.create<linalg::InitTensorOp>(
      loc, ValueRange{lhsDim0, rhsDim1}, elementType);
  Value c0 = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
  Value zeroFill =
      rewriter.create<linalg::FillOp>(loc, initTensor, c0).getResult(0);
  Value matmul = rewriter
                     .create<linalg::MatmulOp>(loc, zeroFill.getType(),
                                               ValueRange{lhs, rhs}, zeroFill)
                     .getResult(0);
  // When constructed with just dynamic sizes, InitTensorOp will have a result
  // type which has all `?`'s for dimensions, which might not be the result
  // type of `op`. The constraints on later linalg ops mean that the result of
  // the MatmulOp will have this type too. So cast it to the desired type so
  // that in the end we have the original result type.
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), matmul);

  return success();
}

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------

namespace {
class ConvertATenToLinalg
    : public ConvertATenToLinalgBase<ConvertATenToLinalg> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }

  void runOnOperation() override {
    (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
  }

  FrozenRewritePatternList getPatterns() {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.add(convertMmOp);
    return std::move(patterns);
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertATenToLinalgPass() {
  return std::make_unique<ConvertATenToLinalg>();
}
@ -0,0 +1,18 @@
add_npcomp_conversion_library(NPCOMPATenToLinalg
  ATenToLinalg.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/ATenToLinalg

  DEPENDS
  NPCOMPConversionPassIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRPass
  NPCOMPATenDialect
  MLIRLinalg
)
@ -173,6 +173,5 @@ void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
  patterns.add<ConvertUnary<aten::TanhOp, tcf::TanhOp>>(context);
  patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
  patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
  patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
  patterns.add<ConvertATenConv2d>(context);
}
@ -1,3 +1,4 @@
add_subdirectory(ATenToLinalg)
add_subdirectory(ATenToTCF)
add_subdirectory(BasicpyToStd)
add_subdirectory(NumpyToTCF)
@ -8,6 +8,7 @@

#include "npcomp/Conversion/Passes.h"

#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
#include "npcomp/Conversion/ATenToTCF/Passes.h"
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
#include "npcomp/Conversion/NumpyToTCF/Passes.h"
@ -75,7 +75,8 @@ bool operator==(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
         std::make_tuple(rhs.hasRank, rhs.sizes, rhs.elementType);
}

// static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ValueKnowledge
//                                      &knowledge) {
//   os << "hasRank = " << knowledge.hasRank << ", sizes = [";
//   llvm::interleaveComma(knowledge.sizes, os);
//   os << "]"
@ -83,6 +84,16 @@ bool operator==(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
//   return os;
// }

Type joinElementTypes(Type lhs, Type rhs) {
  if (lhs.isa<Numpy::AnyDtypeType>())
    return rhs;
  if (rhs.isa<Numpy::AnyDtypeType>())
    return lhs;
  if (lhs == rhs)
    return lhs;
  return Numpy::AnyDtypeType::get(lhs.getContext());
}

// Given two pieces of static knowledge, calculate conservatively the
// information we can be sure about.
ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
@ -116,13 +127,7 @@ ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
    }
  }

  if (!lhs.elementType || lhs.elementType.isa<Numpy::AnyDtypeType>()) {
    result.elementType = rhs.elementType;
  } else if (!rhs.elementType || rhs.elementType.isa<Numpy::AnyDtypeType>()) {
    result.elementType = lhs.elementType;
  } else if (lhs.elementType == rhs.elementType) {
    result.elementType = lhs.elementType;
  }
  result.elementType = joinElementTypes(lhs.elementType, rhs.elementType);
  return result;
}
@ -157,6 +162,52 @@ private:
  DenseMap<Value, ValueKnowledge> facts;
};

// Return the knowledge for the results of an op, based on the knowledge of the
// operands and any information intrinsic to `op`.
static SmallVector<ValueKnowledge>
forwardKnowledgeTransferFunction(Operation *op,
                                 ArrayRef<ValueKnowledge> operandKnowledge) {
  if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
    return {operandKnowledge[0]};
  } else if (isa<aten::MmOp>(op)) {
    auto &lhs = operandKnowledge[0];
    auto &rhs = operandKnowledge[1];
    auto knowledge =
        ValueKnowledge::getMostConservativeKnowledge(op->getContext());
    knowledge.hasRank = true;
    // WARNING: We could be more precise here by calculating the output
    // shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
    // at this stage in the compiler, because we don't really have many static
    // guarantees about the input ranks: `aten` ops do dynamic error checking
    // and safely abort the program. There is nothing preventing us from
    // (correctly!) statically inferring operand shapes that are guaranteed to
    // cause an error at runtime.
    //
    // Example: Suppose a user program calls `aten.mm` with two rank-0
    // operands. The program emits an error when invoked, but when running
    // this pass, we will (correctly!) infer `lhs.hasRank &&
    // lhs.sizes.size() == 0 && rhs.hasRank && rhs.sizes.size() == 0` -- it's
    // not safe to access `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing
    // this transfer function, it's not as simple as taking `lhs.sizes[0]`
    // and `rhs.sizes[1]`, as both of those might read out of bounds of the
    // array. It would require more complicated logic.
    //
    // Just knowing dtypes and ranks is sufficient at this stage
    // in the compiler. The precise per-dimension size propagation is best
    // done lower in the stack, such as at the linalg level, where we have
    // more static guarantees and more structure.
    knowledge.sizes.resize(2, kUnknownSize);
    // TODO: Investigate promotion rules if element types mismatch.
    // This is conservatively correct, assuming that if both element types
    // are the same, then the result is of that same element type.
    knowledge.elementType = joinElementTypes(lhs.elementType, rhs.elementType);
    return {knowledge};
  }
  return SmallVector<ValueKnowledge>(
      op->getNumResults(),
      ValueKnowledge::getMostConservativeKnowledge(op->getContext()));
}

void TypeAnalyzer::propagate(Region &region) {
  bool changed;
  do {
@ -168,10 +219,13 @@ void TypeAnalyzer::propagate(Region &region) {
    for (Operation &op : block.getOperations()) {
      for (Value v : op.getResults())
        changed |= incorporateKnowledge(v, getKnowledgeFromType(v.getType()));
      if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
        changed |= incorporateKnowledge(op.getResult(0),
                                        getKnowledge(op.getOperand(0)));
      }
      auto operandKnowledge = llvm::to_vector<6>(llvm::map_range(
          op.getOperands(), [&](Value v) { return getKnowledge(v); }));
      SmallVector<ValueKnowledge> resultKnowledge =
          forwardKnowledgeTransferFunction(&op, operandKnowledge);
      assert(resultKnowledge.size() == op.getNumResults());
      for (auto t : llvm::zip(op.getResults(), resultKnowledge))
        changed |= incorporateKnowledge(std::get<0>(t), std::get<1>(t));
    }
  };
} while (changed);
@ -61,6 +61,7 @@ TORCH_TO_TCP_PASSES = (
    # Lower to TCP (+ guards) which is the input to codegen backends.
    # Most of this should be subsumed by aten->linalg+guards conversions.
    # (the guard generation will be automated from the linalg Op DSL)
    "func(convert-aten-to-linalg)",
    "func(convert-aten-to-tcf)",
    "func(convert-tcf-to-std)",
    "func(convert-elementwise-to-linalg)",
@ -0,0 +1,44 @@
// RUN: npcomp-opt <%s -convert-aten-to-linalg | FileCheck %s

// CHECK-LABEL: func @aten.mm$basic(
// CHECK-SAME: %[[LHS:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x2xf32> {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[LHS_DIM_0:.*]] = memref.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[LHS_DIM_1:.*]] = memref.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[RHS_DIM_0:.*]] = memref.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[RHS_DIM_1:.*]] = memref.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[EQ:.*]] = cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index
// CHECK: assert %[[EQ]], "mismatching contracting dimension for aten.mm"
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[LHS_DIM_0]], %[[RHS_DIM_1]]] : tensor<?x?xf32>
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[CF0]]) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
// CHECK: return %[[CASTED]] : tensor<?x2xf32>
func @aten.mm$basic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x2xf32> {
  %0 = "aten.mm"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x2xf32>
  return %0 : tensor<?x2xf32>
}

// CHECK-LABEL: func @aten.mm$no_convert$missing_dtype
func @aten.mm$no_convert$missing_dtype(%arg0: tensor<*x!numpy.any_dtype>, %arg1: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
  // CHECK-NEXT: aten.mm
  %0 = "aten.mm"(%arg0, %arg1) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
  return %0 : tensor<*x!numpy.any_dtype>
}

// CHECK-LABEL: func @aten.mm$no_convert$wrong_rank
func @aten.mm$no_convert$wrong_rank(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<*x!numpy.any_dtype> {
  // CHECK-NEXT: aten.mm
  %0 = "aten.mm"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<*x!numpy.any_dtype>
  return %0 : tensor<*x!numpy.any_dtype>
}

// CHECK-LABEL: func @aten.mm$no_convert$result_missing_dtype
func @aten.mm$no_convert$result_missing_dtype(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<*xf32> {
  // CHECK-NEXT: aten.mm
  %0 = "aten.mm"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
@ -24,6 +24,19 @@ func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {

// -----

// CHECK-LABEL: func @f(
// CHECK-SAME: %[[LHS:.*]]: tensor<2x?xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
// CHECK: %[[MM:.*]] = "aten.mm"(%[[LHS]], %[[RHS]]) : (tensor<2x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[MM]] : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
func @f(%arg0: tensor<2x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
  %1 = "aten.mm"(%arg0, %arg1) : (tensor<2x?xf32>, tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype>
  return %1 : tensor<*x!numpy.any_dtype>
}

// -----

// CHECK-LABEL: func @f
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
  // Check propagation through multiple ops.