{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternGuards #-}
#if __GLASGOW_HASKELL__ >= 800
{-# LANGUAGE TemplateHaskellQuotes #-}
#else
{-# LANGUAGE TemplateHaskell #-}
#endif
#ifdef TRUSTWORTHY
# if MIN_VERSION_template_haskell(2,12,0)
{-# LANGUAGE Safe #-}
# else
{-# LANGUAGE Trustworthy #-}
# endif
#endif
module Control.Lens.Internal.FieldTH
( LensRules(..)
, FieldNamer
, DefName(..)
, ClassyNamer
, makeFieldOptics
, makeFieldOpticsForDec
, makeFieldOpticsForDec'
, HasFieldClasses
) where
import Prelude ()
import Control.Lens.At
import Control.Lens.Fold
import Control.Lens.Indexed
import Control.Lens.Internal.TH
import Control.Lens.Internal.Prelude
import Control.Lens.Lens
import Control.Lens.Plated
import Control.Lens.Prism
import Control.Lens.Setter
import Control.Lens.Getter
import Control.Lens.Tuple
import Control.Lens.Traversal
import Control.Monad
import Control.Monad.State
import Language.Haskell.TH.Lens
import Language.Haskell.TH
import qualified Language.Haskell.TH.Datatype as D
import qualified Language.Haskell.TH.Datatype.TyVarBndr as D
import Data.Maybe (fromMaybe,isJust,maybeToList)
import Data.List (nub)
import Data.Either (partitionEithers)
import Data.Semigroup (Any (..))
import Data.Set.Lens
import Data.Map ( Map )
import Data.Set ( Set )
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Traversable as T
-- | Reify the named datatype and generate its field optics according to the
-- given 'LensRules'. The set of field-class names already emitted starts out
-- empty.
makeFieldOptics :: LensRules -> Name -> DecsQ
makeFieldOptics rules tyName = do
  info <- D.reifyDatatype tyName
  evalStateT (makeFieldOpticsForDatatype rules info) Set.empty
-- | Like 'makeFieldOptics', but for a datatype declaration given directly
-- rather than by name. The set of field-class names starts out empty.
makeFieldOpticsForDec :: LensRules -> Dec -> DecsQ
makeFieldOpticsForDec rules dec =
  evalStateT (makeFieldOpticsForDec' rules dec) Set.empty
-- | Normalize a datatype declaration and generate its field optics,
-- threading the set of already-emitted field-class names through the state.
makeFieldOpticsForDec' :: LensRules -> Dec -> HasFieldClasses [Dec]
makeFieldOpticsForDec' rules dec = do
  info <- lift (D.normalizeDec dec)
  makeFieldOpticsForDatatype rules info
-- | Generate the field optics for an already-normalized datatype.
--
-- Each constructor field is mapped through '_fieldToDef' to the definitions
-- it contributes to, and one scaffold is built per distinct 'DefName'. If
-- '_classyLenses' yields a class and method name for the type, the classy
-- driver emits everything. Otherwise each definition is emitted on its own
-- and the results are concatenated.
makeFieldOpticsForDatatype :: LensRules -> D.DatatypeInfo -> HasFieldClasses [Dec]
makeFieldOpticsForDatatype rules info = do
  perDef <- lift $ do
    fieldCons <- traverse normalizeConstructor cons
    let allFields = [ n | (_, fields) <- fieldCons, (Just n, _) <- fields ]
        defCons   = over normFieldLabels (expandName allFields) fieldCons
        allDefs   = Set.fromList
                      [ d | (_, fields) <- defCons
                          , (labels, _) <- fields
                          , (d, _) <- labels ]
    T.sequenceA (Map.fromSet (buildScaffold rules s defCons) allDefs)
  let defs = Map.toList perDef
  case _classyLenses rules tyName of
    Just (className, methodName) ->
      makeClassyDriver rules className methodName s defs
    Nothing ->
      concat <$> traverse (makeFieldOptic rules) defs
  where
    tyName = D.datatypeName info
    s      = datatypeTypeKinded info
    cons   = D.datatypeCons info

    -- Focus the label slot of every field of every constructor.
    normFieldLabels :: Traversal [(Name,[(a,Type)])] [(Name,[(b,Type)])] a b
    normFieldLabels = traverse . _2 . traverse . _1

    -- Pair each definition a labelled field contributes to with that field's
    -- own label. Unlabelled fields contribute nothing.
    expandName :: [Name] -> Maybe Name -> [(DefName, Maybe Name)]
    expandName allFields mName =
      [ (d, mName)
      | n <- maybeToList mName
      , d <- _fieldToDef rules tyName allFields n
      ]
-- | Extract a constructor's name together with its fields. Each field is
-- paired with its record label, if it has one.
--
-- A field whose type mentions any of the constructor's own type variables
-- ('D.constructorVars') has its label dropped to 'Nothing'.
normalizeConstructor ::
  D.ConstructorInfo ->
  Q (Name, [(Maybe Name, Type)])
normalizeConstructor con =
  pure (D.constructorName con, zipWith labelField fieldNames (D.constructorFields con))
  where
    fieldNames = case D.constructorVariant con of
      D.RecordConstructor xs -> map Just xs
      D.NormalConstructor    -> repeat Nothing
      D.InfixConstructor     -> repeat Nothing

    labelField fieldname fieldtype
      | mentionsConVar = (Nothing, fieldtype)
      | otherwise      = (fieldname, fieldtype)
      where
        used = setOf typeVars fieldtype
        mentionsConVar =
          any (\tv -> D.tvName tv `Set.member` used) (D.constructorVars con)
-- | Which kind of clause 'makeFieldClauses' generates for a definition.
data OpticType
  = GetterType -- ^ read-only clause ('makeGetterClause')
  | LensType   -- ^ updating clause ('makeFieldOpticClause')
  | IsoType    -- ^ isomorphism clause ('makeIsoClause')
-- | Work out everything needed to generate one definition.
--
-- Returns three things:
--
-- * the clause shape ('OpticType');
-- * the optic's type ('OpticStab');
-- * one scaffold per constructor, holding the constructor name, its total
--   field count, and the @(label, position)@ of every field that contributes
--   to @defName@.
buildScaffold ::
  LensRules ->
  Type ->
  [(Name, [([(DefName, Maybe Name)], Type)])] ->
  DefName ->
  Q (OpticType, OpticStab, [(Name, Int, [(Maybe Name, Int)])])
buildScaffold rules s cons defName = do
  (s', t, a, b) <- buildStab s [ snd <$> field | (_, fields) <- consForDef, field <- fields ]
  let readOnly
        | lensCase  = getterTypeName
        | otherwise = foldTypeName
      useIso = isoCase && _allowIsos rules
      simpleOptic
        | useIso    = iso'TypeName
        | lensCase  = lens'TypeName
        | otherwise = traversal'TypeName
      fullOptic
        | useIso    = isoTypeName
        | lensCase  = lensTypeName
        | otherwise = traversalTypeName
      defType = case preview _ForallT a of
        Just (_, cx, a') -> OpticSa cx readOnly s' a'
        Nothing
          | not (_allowUpdates rules)                -> OpticSa [] readOnly s' a
          | _simpleLenses rules || s' == t && a == b -> OpticSa [] simpleOptic s' a
          | otherwise                                -> OpticStab fullOptic s' t a b
      opticType
        | has _ForallT a            = GetterType
        | not (_allowUpdates rules) = GetterType
        | isoCase                   = IsoType
        | otherwise                 = LensType
  return (opticType, defType, scaffolds)
  where
    -- Each constructor's fields. Those contributing to defName become
    -- Right (label, type); every other field becomes Left type.
    consForDef :: [(Name, [Either Type (Maybe Name, Type)])]
    consForDef = [ (conName, map categorize fields) | (conName, fields) <- cons ]

    scaffolds :: [(Name, Int, [(Maybe Name, Int)])]
    scaffolds =
      [ (conName, length fields, [ (label, i) | (i, Right (label, _)) <- zip [0 ..] fields ])
      | (conName, fields) <- consForDef
      ]

    categorize :: ([(DefName, Maybe Name)], Type) -> Either Type (Maybe Name, Type)
    categorize (defNames, fieldTy) = case lookup defName defNames of
      Nothing    -> Left fieldTy
      Just label -> Right (label, fieldTy)

    -- Every constructor has exactly one contributing field.
    lensCase :: Bool
    lensCase = all (\(_, fields) -> length [ () | Right _ <- fields ] == 1) consForDef

    -- A single constructor whose single field (position 0) contributes.
    isoCase :: Bool
    isoCase = case scaffolds of
      [(_, 1, [(_, 0)])] -> True
      _                  -> False
-- | The type of a generated optic, before rendering ('stabToType').
data OpticStab
  = OpticStab Name Type Type Type Type
    -- ^ type-changing form: the optic's type constructor and its @s t a b@
  | OpticSa Cxt Name Type Type
    -- ^ simple form: a context, the optic's type constructor and its @s a@
-- | Render an 'OpticStab' as a type. The optic's type constructor is applied
-- to its arguments, and the result is passed to 'quantifyType' together with
-- the stored context (empty for the type-changing form).
stabToType :: OpticStab -> Type
stabToType stab = case stab of
  OpticStab optic s t a b -> quantifyType [] (conAppsT optic [s, t, a, b])
  OpticSa cx optic s a    -> quantifyType cx (conAppsT optic [s, a])
-- | The context stored with an optic type (always empty for 'OpticStab').
stabToContext :: OpticStab -> Cxt
stabToContext stab = case stab of
  OpticSa cx _ _ _ -> cx
  OpticStab{}      -> []
-- | The name of the optic's type constructor.
stabToOptic :: OpticStab -> Name
stabToOptic stab = case stab of
  OpticSa _ optic _ _     -> optic
  OpticStab optic _ _ _ _ -> optic
-- | The source (@s@) type of the optic.
stabToS :: OpticStab -> Type
stabToS stab = case stab of
  OpticSa _ _ sTy _     -> sTy
  OpticStab _ sTy _ _ _ -> sTy
-- | The focus (@a@) type of the optic.
stabToA :: OpticStab -> Type
stabToA stab = case stab of
  OpticSa _ _ _ aTy     -> aTy
  OpticStab _ _ _ aTy _ -> aTy
-- | Compute the @s t a b@ types of an optic.
--
-- The 'Right' entries are the focused field types. They are unified to get
-- @a@, and the resulting substitution is applied to @s@.
--
-- The 'Left' entries are the remaining fields. Type variables they mention,
-- together with the kind variables those have in @s@, stay fixed. Every other
-- type variable of @s@ is renamed to a fresh name to obtain @t@ and @b@.
buildStab :: Type -> [Either Type Type] -> Q (Type,Type,Type,Type)
buildStab s categorizedFields = do
  (subA, a) <- unifyTypes targetFields
  let s' = applyTypeSubst subA s
  freshNames <- T.sequenceA (Map.fromSet (newName . nameBase) unfixedTypeVars)
  let t = substTypeVars freshNames s'
      b = substTypeVars freshNames a
  return (s', t, a, b)
  where
    (fixedFields, targetFields) = partitionEithers categorizedFields

    fixedTypeVars, unfixedTypeVars :: Set Name
    fixedTypeVars   = closeOverKinds (setOf typeVars fixedFields)
    unfixedTypeVars = setOf typeVars s Set.\\ fixedTypeVars

    -- A binder's name and the type variables of its kind (none if unkinded).
    kindVarsOfTvb :: D.TyVarBndr_ flag -> (Name, Set Name)
    kindVarsOfTvb = D.elimTV (\n -> (n, Set.empty)) (\n k -> (n, setOf typeVars k))

    sKindVarMap :: Map Name (Set Name)
    sKindVarMap = Map.fromList (map kindVarsOfTvb (D.freeVariablesWellScoped [s]))

    lookupSKindVars :: Name -> Set Name
    lookupSKindVars n = Map.findWithDefault Set.empty n sKindVarMap

    -- Add the kind variables (as bound in s) of every variable in the set.
    closeOverKinds :: Set Name -> Set Name
    closeOverKinds st =
      Set.unions (map lookupSKindVars (Set.toList st)) `Set.union` st
-- | Emit the declarations for one non-classy definition:
--
-- * an optional field class;
-- * an optional type signature (top-level definitions only);
-- * the definition itself, either a top-level binding for 'TopName' or an
--   instance of the field class for 'MethodName'.
--
-- The recorded set of field-class names is read /before/ this definition's
-- class name is added. As a result, a class already recorded in the state is
-- not generated again.
makeFieldOptic ::
  LensRules ->
  (DefName, (OpticType, OpticStab, [(Name, Int, [(Maybe Name, Int)])])) ->
  HasFieldClasses [Dec]
makeFieldOptic rules (defName, (opticType, defType, cons)) = do
  seen <- get
  case defName of
    MethodName c _ -> addFieldClassName c
    TopName _      -> return ()
  lift $ do
    classDecs <- fieldClass seen
    T.sequenceA (classDecs ++ signature ++ definition)
  where
    -- Skip the class if it is already in scope or was recorded earlier.
    fieldClass seen = case defName of
      MethodName c n | _generateClasses rules -> do
        exists <- isJust <$> lookupTypeName (show c)
        return $ if exists || Set.member c seen
                   then []
                   else [makeFieldClass defType c n]
      _ -> return []

    signature
      | not (_generateSigs rules) = []
      | TopName n <- defName      = [sigD n (return (stabToType defType))]
      | otherwise                 = []

    clauses = makeFieldClauses rules opticType cons

    binding n = funD n clauses : inlinePragma n

    definition = case defName of
      TopName n      -> binding n
      MethodName c n -> [makeFieldInstance defType c (binding n)]
-- | Emit the declarations for a type with classy lenses. The class comes
-- first, and only when '_generateClasses' is set; the instance for the type
-- follows it.
makeClassyDriver ::
  LensRules ->
  Name ->
  Name ->
  Type ->
  [(DefName, (OpticType, OpticStab, [(Name, Int, [(Maybe Name, Int)])]))] ->
  HasFieldClasses [Dec]
makeClassyDriver rules className methodName s defs = do
  classDecs <-
    if _generateClasses rules
      then (: []) <$> lift (makeClassyClass className methodName s defs)
      else return []
  instanceDec <- makeClassyInstance rules className methodName s defs
  return (classDecs ++ [instanceDec])
-- | Declare the class used for classy lenses.
--
-- @s@ is unified with every definition's source type. The class is
-- @className c v1 .. vn@, where the @vi@ are the free variables of the
-- unified type, with the dependency @c -> v1 .. vn@ when there are any.
--
-- Its members are:
--
-- * @methodName@, a @Lens'@ from @c@ to the unified type;
-- * one member per 'TopName' definition, whose default implementation
--   composes (via 'composeValName') @methodName@ with that definition.
makeClassyClass ::
  Name ->
  Name ->
  Type ->
  [(DefName, (OpticType, OpticStab, [(Name, Int, [(Maybe Name, Int)])]))] ->
  DecQ
makeClassyClass className methodName s defs = do
  (sub, s') <- unifyTypes (s : map (stabToS . view (_2 . _2)) defs)
  c <- newName "c"
  let vars      = D.freeVariablesWellScoped [s']
      varNames  = map D.tvName vars
      fundeps   = if null vars then [] else [FunDep [c] varNames]
      methodSig = sigD methodName (return (lens'TypeName `conAppsT` [VarT c, s']))
      -- Only top-level definitions become class members.
      memberDecs (TopName optName, (_, stab, _)) =
        let focus = applyTypeSubst sub (stabToA stab)
            ty    = quantifyType' (Set.fromList (c : varNames))
                                  (stabToContext stab)
                                  (stabToOptic stab `conAppsT` [VarT c, focus])
            body  = appsE [varE composeValName, varE methodName, varE optName]
        in [sigD optName (return ty), valD (varP optName) (normalB body) []]
             ++ inlinePragma optName
      memberDecs _ = []
  classD (cxt []) className (D.plainTV c : vars) fundeps
    (methodSig : concatMap memberDecs defs)
-- | Declare the instance of the classy class for @s@ (applied to its free
-- type variables). The class method is bound to 'idValName'. The definitions
-- are generated by 'makeFieldOptic' with signature and class generation
-- turned off.
makeClassyInstance ::
  LensRules ->
  Name ->
  Name ->
  Type ->
  [(DefName, (OpticType, OpticStab, [(Name, Int, [(Maybe Name, Int)])]))] ->
  HasFieldClasses Dec
makeClassyInstance rules className methodName s defs = do
  memberDecs <- concat <$> traverse (makeFieldOptic memberRules) defs
  let headTy    = className `conAppsT` (s : map tvbToType (D.freeVariablesWellScoped [s]))
      methodDef = valD (varP methodName) (normalB (varE idValName)) []
  lift $ instanceD (cxt []) (return headTy) (methodDef : map return memberDecs)
  where
    memberRules = rules { _generateSigs = False, _generateClasses = False }
-- | Declare a two-parameter field class @className s a@ with the functional
-- dependency @s -> a@. Its single method @methodName@ has the definition's
-- optic type applied to @s@ and @a@, with the definition's context.
makeFieldClass :: OpticStab -> Name -> Name -> DecQ
makeFieldClass defType className methodName =
  classD (cxt []) className [D.plainTV sName, D.plainTV aName] [FunDep [sName] [aName]]
    [sigD methodName (return methodType)]
  where
    sName = mkName "s"
    aName = mkName "a"
    opticTy = stabToOptic defType `conAppsT` [VarT sName, VarT aName]
    methodType = quantifyType' (Set.fromList [sName, aName]) (stabToContext defType) opticTy
-- | Declare an instance of the field class @className@ for the definition's
-- @s@ and @a@ types, containing the given declarations.
--
-- If @a@ mentions a type family (checked after resolving type synonyms), the
-- instance head uses a fresh type variable in place of @a@. The instance
-- context then equates that variable with @a@, because GHC rejects type
-- family applications in instance heads.
makeFieldInstance :: OpticStab -> Name -> [DecQ] -> DecQ
makeFieldInstance defType className decs = do
  hasFamilies <- D.resolveTypeSynonyms a >>= mentionsTypeFamily
  if hasFamilies
    then do
      placeholder <- VarT <$> newName "a"
      instanceWith [return (D.equalPred placeholder a)] [s, placeholder]
    else instanceWith [] [s, a]
  where
    s = stabToS defType
    a = stabToA defType

    instanceWith context headTys =
      instanceD (cxt context) (return (className `conAppsT` headTys)) decs

    -- A constructor counts if it reifies to an open or closed type family.
    -- Names that cannot be reified are treated as not being families.
    mentionsTypeFamily :: Type -> Q Bool
    mentionsTypeFamily (ConT nm) =
      recover (pure False) (has (_FamilyI . _1 . _TypeFamilyD) <$> reify nm)
    mentionsTypeFamily ty =
      or <$> traverse mentionsTypeFamily (ty ^.. plate)

    _TypeFamilyD :: Getting Any Dec ()
    _TypeFamilyD = _OpenTypeFamilyD . united <> _ClosedTypeFamilyD . united
-- | Build one clause per constructor scaffold for the requested clause kind.
-- Lazy patterns and record-update syntax are only used when the type has
-- exactly one constructor.
makeFieldClauses :: LensRules -> OpticType -> [(Name, Int, [(Maybe Name, Int)])] -> [ClauseQ]
makeFieldClauses rules opticType cons =
  case opticType of
    IsoType    -> [ makeIsoClause conName | (conName, _, _) <- cons ]
    GetterType -> [ makeGetterClause conName fieldCount (map snd fields)
                  | (conName, fieldCount, fields) <- cons ]
    LensType   -> [ makeFieldOpticClause conName fieldCount fields irref recSyn
                  | (conName, fieldCount, fields) <- cons ]
  where
    singleCon = case cons of
      [_] -> True
      _   -> False
    irref  = _lazyPatterns rules && singleCon
    recSyn = _recordSyntax rules && singleCon
-- | A clause that ignores its first argument, matches the constructor
-- binding every field, and returns the rebuilt value wrapped with
-- 'pureValName'.
makePureClause :: Name -> Int -> ClauseQ
makePureClause conName fieldCount = do
  vars <- newNames "x" fieldCount
  let rebuilt = appsE (conE conName : map varE vars)
  clause [wildP, conP conName (map varP vars)]
         (normalB (varE pureValName `appE` rebuilt))
         []
-- | Clause for a read-only optic over one constructor.
--
-- With no focused fields the value is rebuilt unchanged ('makePureClause').
-- Otherwise the generated clause binds only the focused positions, using
-- wildcards for the rest. It applies @f@ to each bound field, wraps the first
-- result with 'phantomValName', and folds in the remaining results with
-- 'apValName'.
makeGetterClause :: Name -> Int -> [Int] -> ClauseQ
makeGetterClause conName fieldCount [] = makePureClause conName fieldCount
makeGetterClause conName fieldCount fields =
  do f  <- newName "f"
     xs <- newNames "x" (length fields)
     -- Bind a fresh name at each focused position; wildcard everything else.
     let pats (i:is) (y:ys)
           | i `elem` fields = varP y : pats is ys
           | otherwise = wildP : pats is (y:ys)
         pats is _ = map (const wildP) is
         fxs = [ appE (varE f) (varE x) | x <- xs ]
         -- Pattern-match instead of head/tail so the function is total.
         body = case fxs of
           fx : rest -> foldl' (\acc e -> appsE [varE apValName, acc, e])
                               (appE (varE phantomValName) fx)
                               rest
           -- Unreachable: the empty case is handled by the equation above,
           -- and 'xs' holds one name per focused field.
           [] -> fail "makeGetterClause: Bug: Unexpected empty list of fields"
     clause [varP f, conP conName (pats [0..fieldCount - 1] xs)]
            (normalB body)
            []
-- | Clause for an updating optic (lens or traversal) over one constructor.
-- There are three cases:
--
-- * No focused fields: the value is rebuilt unchanged ('makePureClause').
-- * Exactly one focused field that has a record label, with the record-syntax
--   flag set (which 'makeFieldClauses' only does for single-constructor
--   types): the clause is
--   @f r -> fmap (\\x -> r { field = x }) (f (field r))@.
-- * Otherwise: the constructor is matched positionally (irrefutably when
--   @irref@ is set), @f@ is applied to every focused field, and the results
--   are combined with 'fmapValName' and 'apValName' to rebuild the
--   constructor.
makeFieldOpticClause :: Name -> Int -> [(Maybe Name, Int)] -> Bool -> Bool -> ClauseQ
makeFieldOpticClause conName fieldCount [] _ _ =
  makePureClause conName fieldCount
makeFieldOpticClause _ _ [(Just fieldName, _)] _ True =
  do f <- newName "f"
     r <- newName "r"
     x <- newName "x"
     -- Refer to fmap through 'fmapValName', as the positional case below
     -- does, instead of through a quotation bracket.
     -- NOTE(review): this applies the field selector as a function, which
     -- presumes selectors are usable in the splicing module. Confirm this
     -- for NoFieldSelectors users.
     let body = appsE [ varE fmapValName
                      , lamE [varP x] (recUpdE (varE r) [(,) fieldName <$> varE x])
                      , varE f `appE` (varE fieldName `appE` varE r)
                      ]
     clause [varP f, varP r] (normalB body) []
makeFieldOpticClause conName fieldCount ((_, field):fieldsWithNames) irref _ =
  do f <- newName "f"
     xs <- newNames "x" fieldCount
     ys <- newNames "y" (1 + length fieldsWithNames)
     let fields = snd <$> fieldsWithNames
         -- The pattern variables, with each focused position replaced by
         -- the corresponding new value.
         xs' = foldr (\(i,x) -> set (ix i) x) xs (zip (field:fields) ys)
         -- Indices come from the scaffold, so they are below fieldCount.
         mkFx i = appE (varE f) (varE (xs !! i))
         body0 = appsE [ varE fmapValName
                       , lamE (map varP ys) (appsE (conE conName : map varE xs'))
                       , mkFx field
                       ]
         body = foldl (\a b -> appsE [varE apValName, a, mkFx b]) body0 fields
     let wrap = if irref then tildeP else id
     clause [varP f, wrap (conP conName (map varP xs))]
            (normalB body)
            []
-- | Clause for an isomorphism with a single-field constructor. It applies
-- 'isoValName' to a lambda that unwraps the constructor and to the
-- constructor itself.
makeIsoClause :: Name -> ClauseQ
makeIsoClause conName = clause [] (normalB isoExp) []
  where
    isoExp = appsE [varE isoValName, unwrap, conE conName]
    unwrap = do
      x <- newName "x"
      lam1E (conP conName [varP x]) (varE x)
-- | Unify a non-empty list of types from left to right, returning the
-- accumulated substitution and the unified type. An empty list is a bug in
-- the caller and fails in 'Q'.
unifyTypes :: [Type] -> Q (Map Name Type, Type)
unifyTypes types = case types of
  []         -> fail "unifyTypes: Bug: Unexpected empty list"
  first : ts -> foldM (uncurry unify1) (Map.empty, first) ts
-- | Unify two types under an existing substitution, returning the extended
-- substitution together with the unified type.
--
-- The equations below are tried in order, and that order is significant:
-- later equations assume earlier ones did not match.
unify1 :: Map Name Type -> Type -> Type -> Q (Map Name Type, Type)
-- A variable already bound in the substitution is replaced by its binding.
unify1 sub (VarT x) y
  | Just r <- Map.lookup x sub = unify1 sub r y
unify1 sub x (VarT y)
  | Just r <- Map.lookup y sub = unify1 sub x r
-- Syntactically equal types unify without extending the substitution.
unify1 sub x y
  | x == y = return (sub, x)
-- Applications unify component-wise. The function part is re-substituted
-- with the bindings learned while unifying the argument.
unify1 sub (AppT f1 x1) (AppT f2 x2) =
  do (sub1, f) <- unify1 sub f1 f2
     (sub2, x) <- unify1 sub1 x1 x2
     return (sub2, AppT (applyTypeSubst sub2 f) x)
-- An unbound variable on the right is bound to the left type. The binding
-- is refused (occurs check) if the substituted left type mentions it.
unify1 sub x (VarT y)
  | elemOf typeVars y (applyTypeSubst sub x) =
    fail "Failed to unify types: occurs check"
  | otherwise = return (Map.insert y x sub, x)
-- An unbound variable on the left: swap the arguments and retry.
unify1 sub (VarT x) y = unify1 sub y (VarT x)
-- forall types without contexts: unify the bodies, then rename the binders
-- of both sides through the substitution and drop duplicates.
unify1 sub (ForallT v1 [] t1) (ForallT v2 [] t2) =
  do (sub1,t) <- unify1 sub t1 t2
     v <- fmap nub (traverse (limitedSubst sub1) (v1++v2))
     return (sub1, ForallT v [] t)
unify1 _ x y = fail ("Failed to unify types: " ++ show (x,y))
-- | Rename a binder through a substitution, following chains of
-- variable-to-variable bindings. Fails if the binder's variable is bound to
-- anything other than a type variable.
limitedSubst :: Map Name Type -> D.TyVarBndrSpec -> Q D.TyVarBndrSpec
limitedSubst sub tv =
  case Map.lookup (D.tvName tv) sub of
    Nothing       -> return tv
    Just (VarT m) -> limitedSubst sub (D.mapTVName (const m) tv)
    Just _        -> fail "Unable to unify exotic higher-rank type"
-- | Apply a substitution to a type. Bound type variables are replaced via
-- 'rewrite' until no bound variable remains.
applyTypeSubst :: Map Name Type -> Type -> Type
applyTypeSubst sub = rewrite substVar
  where
    substVar ty = case ty of
      VarT n -> Map.lookup n sub
      _      -> Nothing
-- | Configuration for field-optic generation.
data LensRules = LensRules
  { _simpleLenses    :: Bool
    -- ^ generate simple (@s a@) optics even when a type-changing one is possible
  , _generateSigs    :: Bool
    -- ^ emit type signatures for top-level definitions
  , _generateClasses :: Bool
    -- ^ emit field classes, or the class for classy lenses
  , _allowIsos       :: Bool
    -- ^ use an iso type for a single constructor with a single focused field
  , _allowUpdates    :: Bool
    -- ^ when 'False', only read-only (getter or fold) optics are generated
  , _lazyPatterns    :: Bool
    -- ^ match a single-constructor type with an irrefutable pattern
  , _recordSyntax    :: Bool
    -- ^ use record-update syntax for a single-constructor type's labelled field
  , _fieldToDef      :: FieldNamer
    -- ^ which definitions each field contributes to
  , _classyLenses    :: ClassyNamer
    -- ^ class and method names for classy lenses, if any
  }
-- | Given the type's name, the names of all its fields, and one field's name,
-- return the definitions that field contributes to.
type FieldNamer = Name -> [Name] -> Name -> [DefName]
-- | The name of a definition to generate.
data DefName
  = TopName Name
    -- ^ a top-level binding with this name
  | MethodName Name Name
    -- ^ a method of a field class: the class name, then the method name
  deriving (Show, Eq, Ord)
-- | Given the type's name, optionally return the class name and method name
-- for classy lenses. 'Nothing' means no classy lenses are generated.
type ClassyNamer = Name -> Maybe (Name, Name)
-- | The generation monad: 'Q' carrying the set of field-class names recorded
-- so far.
type HasFieldClasses = StateT (Set Name) Q
-- | Record a field-class name in the state.
addFieldClassName :: Name -> HasFieldClasses ()
addFieldClassName = modify . Set.insert