module Data.ASN1.Prim
(
ASN1(..)
, ASN1ConstructionType(..)
, encodeHeader
, encodePrimitiveHeader
, encodePrimitive
, decodePrimitive
, encodeConstructed
, encodeList
, encodeOne
, mkSmallestLength
, getBoolean
, getInteger
, getBitString
, getOctetString
, getUTF8String
, getNumericString
, getPrintableString
, getT61String
, getVideoTexString
, getIA5String
, getNull
, getOID
, getUTCTime
, getGeneralizedTime
, getGraphicString
, getVisibleString
, getGeneralString
, getUniversalString
, getCharacterString
, getBMPString
, putUTCTime
, putGeneralizedTime
, putInteger
, putBitString
, putString
, putOID
) where
import Data.ASN1.Internal
import Data.ASN1.Raw
import Data.ASN1.Stream
import Data.Bits
import Data.Word
import Data.List (unfoldr)
import Data.ByteString (ByteString)
import Data.Char (ord)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as T
import Data.Text.Lazy.Encoding (decodeASCII, decodeUtf8, decodeUtf32BE, encodeUtf8, encodeUtf32BE)
-- | Encode text as UCS-2 big-endian: two bytes per character, most
-- significant byte first (the byte order a BMPString uses).
--
-- Characters above U+FFFF do not fit in UCS-2; their high bits are silently
-- truncated by the 'fromIntegral' conversion to a byte.
encodeUCS2BE :: Text -> L.ByteString
encodeUCS2BE = L.pack . concatMap codeUnit . T.unpack
  where
    codeUnit c =
        let (hi, lo) = fromEnum c `divMod` 256
        in [fromIntegral hi, fromIntegral lo]

-- | Decode UCS-2 big-endian bytes (two bytes per character, most significant
-- byte first) back into text. A trailing odd byte is dropped.
decodeUCS2BE :: L.ByteString -> Text
decodeUCS2BE = T.pack . loop
  where
    loop bs
        | L.null bs = []
        | otherwise =
            -- Split the *current* remainder, not the original input; splitting
            -- the original made this loop repeat its first character forever.
            let (unit, rest) = L.splitAt 2 bs
            in case L.unpack unit of
                [hi, lo] -> toEnum (fromIntegral hi * 256 + fromIntegral lo) : loop rest
                _        -> loop rest
-- | Build the header for a value.
--
-- The 'Bool' is stored as the header's constructed flag and the length is
-- used unchanged. Universal types get their fixed tag number. 'Start'
-- containers and 'Other' values carry their own class and tag. An 'End'
-- marker has no header of its own and is rejected with 'error'.
encodeHeader :: Bool -> ASN1Length -> ASN1 -> ASN1Header
encodeHeader pc len value =
    case value of
        Boolean _                -> universal 0x1
        IntVal _                 -> universal 0x2
        BitString _ _            -> universal 0x3
        OctetString _            -> universal 0x4
        Null                     -> universal 0x5
        OID _                    -> universal 0x6
        Real _                   -> universal 0x9
        Enumerated               -> universal 0xa
        UTF8String _             -> universal 0xc
        NumericString _          -> universal 0x12
        PrintableString _        -> universal 0x13
        T61String _              -> universal 0x14
        VideoTexString _         -> universal 0x15
        IA5String _              -> universal 0x16
        UTCTime _                -> universal 0x17
        GeneralizedTime _        -> universal 0x18
        GraphicString _          -> universal 0x19
        VisibleString _          -> universal 0x1a
        GeneralString _          -> universal 0x1b
        UniversalString _        -> universal 0x1c
        CharacterString _        -> universal 0x1d
        BMPString _              -> universal 0x1e
        Start Sequence           -> universal 0x10
        Start Set                -> universal 0x11
        Start (Container tc tag) -> ASN1Header tc tag pc len
        Other tc tag _           -> ASN1Header tc tag pc len
        End _                    -> error "this should not happen"
  where
    -- Header in the Universal class with the caller's flag and length.
    universal tagNumber = ASN1Header Universal tagNumber pc len
-- | Header for a primitive (non-constructed) value: 'encodeHeader' with the
-- constructed flag cleared.
encodePrimitiveHeader :: ASN1Length -> ASN1 -> ASN1Header
encodePrimitiveHeader len value = encodeHeader False len value
-- | Serialise the contents (without header) of a primitive value.
--
-- String-typed values are first converted to bytes:
--
-- * UTF8, Printable, T61 and IA5 strings are written as UTF-8.
-- * UniversalString is written as UTF-32BE.
-- * BMPString is written through 'encodeUCS2BE'.
--
-- Other string types already hold raw bytes. 'Other' values are emitted
-- verbatim. 'Start'/'End' markers are not primitives and raise 'error'.
encodePrimitiveData :: ASN1 -> ByteString
encodePrimitiveData value =
    case value of
        Boolean b         -> B.singleton (if b then 0xff else 0)
        IntVal i          -> putInteger i
        BitString i bits  -> putBitString i bits
        OctetString b     -> putString b
        Null              -> B.empty
        OID oid           -> putOID oid
        -- NOTE(review): REAL and ENUMERATED are written with empty contents;
        -- real encoding is not implemented (decodePrimitive also rejects it).
        Real _            -> B.empty
        Enumerated        -> B.empty
        UTF8String b      -> viaUtf8 b
        NumericString b   -> putString b
        PrintableString b -> viaUtf8 b
        T61String b       -> viaUtf8 b
        VideoTexString b  -> putString b
        IA5String b       -> viaUtf8 b
        UTCTime t         -> putUTCTime t
        GeneralizedTime t -> putGeneralizedTime t
        GraphicString b   -> putString b
        VisibleString b   -> putString b
        GeneralString b   -> putString b
        UniversalString b -> putString $ encodeUtf32BE $ T.pack b
        CharacterString b -> putString b
        BMPString b       -> putString $ encodeUCS2BE $ T.pack b
        Other _ _ b       -> b
        o                 -> error ("not a primitive " ++ show o)
  where
    -- Character data carried as String, written out as UTF-8 bytes.
    viaUtf8 = putString . encodeUtf8 . T.pack
-- | Encode a primitive value into its header and contents events.
--
-- Returns the total encoded size in bytes (header + contents) together with
-- the events @[Header hdr, Primitive contents]@. The length field uses the
-- same shortest short/long form as constructed values ('mkSmallestLength'),
-- instead of a private copy of that logic.
encodePrimitive :: ASN1 -> (Int, [ASN1Event])
encodePrimitive a = (B.length (putHeader hdr) + blen, [Header hdr, Primitive b])
  where
    b    = encodePrimitiveData a
    blen = B.length b
    hdr  = encodePrimitiveHeader (mkSmallestLength blen) a
-- | Encode a single non-container value. A 'Start' marker needs its children
-- and must go through 'encodeList'/'encodeConstructed' instead.
encodeOne :: ASN1 -> (Int, [ASN1Event])
encodeOne value =
    case value of
        Start _ -> error "encode one cannot do start"
        _       -> encodePrimitive value
-- | Encode a flat list of values, in which containers appear as a 'Start'
-- marker followed by their children.
--
-- Returns the total encoded size and the concatenated events. Stray 'End'
-- markers at this level are skipped.
encodeList :: [ASN1] -> (Int, [ASN1Event])
encodeList [] = (0, [])
encodeList (End _ : rest) = encodeList rest
encodeList (item : rest) = (restLen + itemLen, itemEvents ++ restEvents)
  where
    -- For a 'Start', 'getConstructedEnd' (defined elsewhere) splits the tail
    -- into the container's children and whatever follows. NOTE(review): this is
    -- presumably up to the matching End; confirm in Data.ASN1.Stream. Any
    -- other value is encoded alone and the whole tail remains.
    (itemLen, itemEvents, remaining) =
        case item of
            Start _ ->
                let (children, after) = getConstructedEnd 0 rest
                    (len, evs)        = encodeConstructed item children
                in (len, evs, after)
            _ ->
                let (len, evs) = encodeOne item
                in (len, evs, rest)
    (restLen, restEvents) = encodeList remaining
-- | Encode a container: the 'Start' marker plus its already-split children.
--
-- The children are encoded first so the header can carry their total length
-- (shortest form, constructed flag set). The events are wrapped in
-- 'ConstructionBegin'/'ConstructionEnd'. Any non-'Start' first argument is
-- rejected with 'error'.
encodeConstructed :: ASN1 -> [ASN1] -> (Int, [ASN1Event])
encodeConstructed start@(Start _) children = (headerSize + bodySize, events)
  where
    (bodySize, bodyEvents) = encodeList children
    header     = encodeHeader True (mkSmallestLength bodySize) start
    headerSize = B.length (putHeader header)
    events     = Header header : ConstructionBegin : bodyEvents ++ [ConstructionEnd]
encodeConstructed _ _ = error "not a start node"
-- | Choose the shortest length encoding for a content size.
--
-- Sizes below 0x80 use the short form. Larger sizes use the long form,
-- together with the number of bytes needed to write the size in base 256.
mkSmallestLength :: Int -> ASN1Length
mkSmallestLength i
    | i >= 0x80 = LenLong (byteCount i) i
    | otherwise = LenShort i
  where
    -- Count of base-256 digits of a positive value (only called for i >= 0x80).
    byteCount :: Int -> Int
    byteCount = length . takeWhile (> 0) . iterate (`div` 256)
-- | Result of decoding one primitive value.
type ASN1Ret = Either ASN1Err ASN1

-- | Decode the contents of a primitive value according to its header.
--
-- Universal tags dispatch to the matching @get*@ decoder. Universal types
-- without a decoder return 'ASN1NotImplemented'. Every other class/tag is
-- kept verbatim as 'Other'.
--
-- The header's constructed flag is not inspected here, so a malformed input
-- can present SEQUENCE (0x10) or SET (0x11) as a primitive. Those are reported
-- as 'ASN1Misc' errors rather than crashing the decoder with 'error'.
decodePrimitive :: ASN1Header -> B.ByteString -> ASN1Ret
decodePrimitive (ASN1Header Universal 0x1 _ _) p  = getBoolean False p
decodePrimitive (ASN1Header Universal 0x2 _ _) p  = getInteger p
decodePrimitive (ASN1Header Universal 0x3 _ _) p  = getBitString p
decodePrimitive (ASN1Header Universal 0x4 _ _) p  = getOctetString p
decodePrimitive (ASN1Header Universal 0x5 _ _) p  = getNull p
decodePrimitive (ASN1Header Universal 0x6 _ _) p  = getOID p
decodePrimitive (ASN1Header Universal 0x7 _ _) _  = Left $ ASN1NotImplemented "Object Descriptor"
decodePrimitive (ASN1Header Universal 0x8 _ _) _  = Left $ ASN1NotImplemented "External"
decodePrimitive (ASN1Header Universal 0x9 _ _) _  = Left $ ASN1NotImplemented "real"
decodePrimitive (ASN1Header Universal 0xa _ _) _  = Left $ ASN1NotImplemented "enumerated"
decodePrimitive (ASN1Header Universal 0xb _ _) _  = Left $ ASN1NotImplemented "EMBEDDED PDV"
decodePrimitive (ASN1Header Universal 0xc _ _) p  = getUTF8String p
decodePrimitive (ASN1Header Universal 0xd _ _) _  = Left $ ASN1NotImplemented "RELATIVE-OID"
decodePrimitive (ASN1Header Universal 0x10 _ _) _ = Left $ ASN1Misc "sequence not a primitive"
decodePrimitive (ASN1Header Universal 0x11 _ _) _ = Left $ ASN1Misc "set not a primitive"
decodePrimitive (ASN1Header Universal 0x12 _ _) p = getNumericString p
decodePrimitive (ASN1Header Universal 0x13 _ _) p = getPrintableString p
decodePrimitive (ASN1Header Universal 0x14 _ _) p = getT61String p
decodePrimitive (ASN1Header Universal 0x15 _ _) p = getVideoTexString p
decodePrimitive (ASN1Header Universal 0x16 _ _) p = getIA5String p
decodePrimitive (ASN1Header Universal 0x17 _ _) p = getUTCTime p
decodePrimitive (ASN1Header Universal 0x18 _ _) p = getGeneralizedTime p
decodePrimitive (ASN1Header Universal 0x19 _ _) p = getGraphicString p
decodePrimitive (ASN1Header Universal 0x1a _ _) p = getVisibleString p
decodePrimitive (ASN1Header Universal 0x1b _ _) p = getGeneralString p
decodePrimitive (ASN1Header Universal 0x1c _ _) p = getUniversalString p
decodePrimitive (ASN1Header Universal 0x1d _ _) p = getCharacterString p
decodePrimitive (ASN1Header Universal 0x1e _ _) p = getBMPString p
decodePrimitive (ASN1Header tc tag _ _) p         = Right $ Other tc tag p
-- | Decode BOOLEAN contents, which must be exactly one byte.
--
-- 0x00 is False and 0xff is True. Any other single byte is True, unless
-- @isDer@ is set, in which case it is rejected as non-canonical.
getBoolean :: Bool -> ByteString -> Either ASN1Err ASN1
getBoolean isDer s =
    case B.unpack s of
        [0x00] -> Right (Boolean False)
        [0xff] -> Right (Boolean True)
        [_]
            | isDer     -> Left $ ASN1PolicyFailed "DER" "boolean value not canonical"
            | otherwise -> Right (Boolean True)
        _ -> Left $ ASN1Misc "boolean: length not within bound"
-- | Decode INTEGER contents.
--
-- Empty contents are rejected. For two or more bytes, a redundant leading
-- byte is rejected as "not shortest encoding": 0xff before a byte with bit 7
-- set, or 0x00 before a byte with bit 7 clear.
getInteger :: ByteString -> Either ASN1Err ASN1
getInteger s =
    case B.unpack s of
        [] -> Left $ ASN1Misc "integer: null encoding"
        [_] -> Right decoded
        (first : second : _)
            | redundantLead first second -> Left $ ASN1Misc "integer: not shortest encoding"
            | otherwise                  -> Right decoded
  where
    decoded = IntVal $ snd $ intOfBytes s
    redundantLead a b = (a == 0xff && testBit b 7) || (a == 0x00 && not (testBit b 7))
-- | Decode BIT STRING contents.
--
-- The first byte is the number of unused bits in the final octet and must be
-- in 0..7. The remaining bytes are returned unchanged. An ASCII digit
-- '0'..'7' in that first position is also accepted and converted to its
-- numeric value (NOTE(review): presumably a workaround for non-conforming
-- encoders; confirm before removing).
--
-- Empty contents are reported as an 'ASN1Misc' error instead of crashing on
-- 'B.head'.
getBitString :: ByteString -> Either ASN1Err ASN1
getBitString s =
    case B.uncons s of
        Nothing -> Left $ ASN1Misc "bitstring: empty encoding"
        Just (lead, xs) ->
            let unused = fromAsciiDigit lead
            in if unused <= 7
                then Right $ BitString (fromIntegral unused) (L.fromChunks [xs])
                else Left $ ASN1Misc ("bitstring: skip number not within bound " ++ show unused ++ " " ++ show s)
  where
    -- Map ASCII '0'..'7' (48..55) to 0..7; leave every other byte untouched.
    fromAsciiDigit :: Word8 -> Word8
    fromAsciiDigit b
        | b >= 48 && b <= 48 + 7 = b - fromIntegral (ord '0')
        | otherwise              = b
-- | Run a validity check over string contents. If the check reports no error,
-- return the bytes as a lazy ByteString; otherwise return that error.
getString :: (ByteString -> Maybe ASN1Err) -> ByteString -> Either ASN1Err L.ByteString
getString check s = maybe (Right (L.fromChunks [s])) Left (check s)
-- | OCTET STRING: contents kept as raw bytes.
getOctetString :: ByteString -> Either ASN1Err ASN1
getOctetString = rawStringAs OctetString

-- | NumericString: contents kept as raw bytes.
getNumericString :: ByteString -> Either ASN1Err ASN1
getNumericString = rawStringAs NumericString

-- | PrintableString: contents decoded with 'decodeASCII'.
getPrintableString :: ByteString -> Either ASN1Err ASN1
getPrintableString = textStringAs PrintableString decodeASCII

-- | UTF8String: contents decoded with 'decodeUtf8'.
getUTF8String :: ByteString -> Either ASN1Err ASN1
getUTF8String = textStringAs UTF8String decodeUtf8

-- | T61String: contents decoded with 'decodeASCII'.
getT61String :: ByteString -> Either ASN1Err ASN1
getT61String = textStringAs T61String decodeASCII

-- | VideoTexString: contents kept as raw bytes.
getVideoTexString :: ByteString -> Either ASN1Err ASN1
getVideoTexString = rawStringAs VideoTexString

-- | IA5String: contents decoded with 'decodeASCII'.
getIA5String :: ByteString -> Either ASN1Err ASN1
getIA5String = textStringAs IA5String decodeASCII

-- | GraphicString: contents kept as raw bytes.
getGraphicString :: ByteString -> Either ASN1Err ASN1
getGraphicString = rawStringAs GraphicString

-- | VisibleString: contents kept as raw bytes.
getVisibleString :: ByteString -> Either ASN1Err ASN1
getVisibleString = rawStringAs VisibleString

-- | GeneralString: contents kept as raw bytes.
getGeneralString :: ByteString -> Either ASN1Err ASN1
getGeneralString = rawStringAs GeneralString

-- | UniversalString: contents decoded with 'decodeUtf32BE'.
getUniversalString :: ByteString -> Either ASN1Err ASN1
getUniversalString = textStringAs UniversalString decodeUtf32BE

-- | CharacterString: contents kept as raw bytes.
getCharacterString :: ByteString -> Either ASN1Err ASN1
getCharacterString = rawStringAs CharacterString

-- | BMPString: contents decoded with 'decodeUCS2BE'.
getBMPString :: ByteString -> Either ASN1Err ASN1
getBMPString = textStringAs BMPString decodeUCS2BE

-- | Wrap the contents, unvalidated, in the given constructor.
rawStringAs :: (L.ByteString -> ASN1) -> ByteString -> Either ASN1Err ASN1
rawStringAs ctor = either Left (Right . ctor) . getString (const Nothing)

-- | Decode the contents to text with @decoder@, then wrap the resulting
-- 'String' in @ctor@.
--
-- NOTE(review): the text package documents decodeUtf8 (and its relatives)
-- as throwing on malformed input. Such an exception escapes lazily instead
-- of becoming a Left. Consider decodeUtf8' if the text version in use has it.
textStringAs :: (String -> ASN1) -> (L.ByteString -> Text) -> ByteString -> Either ASN1Err ASN1
textStringAs ctor decoder = rawStringAs (ctor . T.unpack . decoder)
-- | Decode NULL, whose contents must be empty.
getNull :: ByteString -> Either ASN1Err ASN1
getNull s
    | B.null s  = Right Null
    | otherwise = Left $ ASN1Misc "Null: data length not within bound"
-- | Decode OBJECT IDENTIFIER contents.
--
-- The contents are a series of subidentifiers. Each is a big-endian base-128
-- number in which every byte except the last has bit 7 set.
--
-- The first subidentifier packs the first two arcs: values below 80 split
-- into (v `div` 40, v `mod` 40); anything larger belongs to arc 2 with
-- second arc v - 80 (X.690 8.19.4). The first subidentifier is decoded in
-- full before splitting, so a multi-byte first subidentifier (e.g. 2.100) is
-- handled correctly. Empty contents are an 'ASN1Misc' error rather than a
-- pattern-match crash.
getOID :: ByteString -> Either ASN1Err ASN1
getOID s =
    case groupOID (B.unpack s) of
        []             -> Left $ ASN1Misc "OID: empty encoding"
        (first : rest) -> Right $ OID (splitFirst first ++ rest)
  where
    splitFirst :: Integer -> [Integer]
    splitFirst v
        | v < 80    = [v `div` 40, v `mod` 40]
        | otherwise = [2, v - 80]

    -- Combine each group of 7-bit chunks into one subidentifier value.
    groupOID :: [Word8] -> [Integer]
    groupOID = map (foldl (\acc n -> (acc `shiftL` 7) + fromIntegral n) 0) . groupSubOID

    groupSubOID :: [Word8] -> [[Word8]]
    groupSubOID = unfoldr groupSubOIDHelper

    groupSubOIDHelper [] = Nothing
    groupSubOIDHelper l  = Just $ spanSubOIDbound l

    -- Take bytes up to and including the first one with bit 7 clear,
    -- stripping the continuation bit from each.
    spanSubOIDbound [] = ([], [])
    spanSubOIDbound (a:as) = if testBit a 7 then (clearBit a 7 : ys, zs) else ([a], as)
      where (ys, zs) = spanSubOIDbound as
-- | Decode UTCTime contents of the exact form YYMMDDhhmmss followed by a
-- zone byte.
--
-- The zone flag in the result is True when that last byte is 'Z' (90). A
-- two-digit year of 50 or more means 19YY; below 50 means 20YY (RFC 5280
-- 4.1.2.5.1). Any other length, or a non-digit in the date/time fields, is
-- rejected.
getUTCTime :: ByteString -> Either ASN1Err ASN1
getUTCTime s =
    case B.unpack s of
        [y1, y2, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, z]
            | all isDigitByte [y1, y2, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2] ->
                let yy   = integerise y1 y2
                    year = if yy < 50 then 2000 + yy else 1900 + yy
                in Right $ UTCTime
                    ( year, integerise m1 m2, integerise d1 d2
                    , integerise h1 h2, integerise mi1 mi2, integerise s1 s2
                    , z == 90 )
        _ -> Left $ ASN1Misc "utctime unexpected format"
  where
    isDigitByte b = b >= 48 && b <= 57
    -- Two ASCII digits to their numeric value.
    integerise a b = (fromIntegral a - ord '0') * 10 + (fromIntegral b - ord '0')
-- | Decode GeneralizedTime contents of the exact form YYYYMMDDhhmmss
-- followed by a zone byte.
--
-- The zone flag in the result is True when that last byte is 'Z' (90).
-- Fractional seconds, explicit offsets, other lengths and non-digit
-- date/time fields are all rejected.
getGeneralizedTime :: ByteString -> Either ASN1Err ASN1
getGeneralizedTime s =
    case B.unpack s of
        [y1, y2, y3, y4, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, z]
            | all isDigitByte [y1, y2, y3, y4, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2] ->
                let year = integerise y1 y2 * 100 + integerise y3 y4
                in Right $ GeneralizedTime
                    ( year, integerise m1 m2, integerise d1 d2
                    , integerise h1 h2, integerise mi1 mi2, integerise s1 s2
                    , z == 90 )
        _ -> Left $ ASN1Misc "generalizedtime unexpected format"
  where
    isDigitByte b = b >= 48 && b <= 57
    -- Two ASCII digits to their numeric value.
    integerise a b = (fromIntegral a - ord '0') * 10 + (fromIntegral b - ord '0')
-- | Render a (year, month, day, hour, minute, second, isZulu) tuple as ASCII
-- digits.
--
-- The generalized form writes a four-digit year; the UTC form writes only
-- the last two digits of the year. The final byte is 'Z' (90) when the flag
-- is set and a 0 byte otherwise.
putTime :: Bool -> (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putTime generalized (y, m, d, h, mi, s, z) =
    B.pack $ yearDigits ++ concatMap twoDigits [m, d, h, mi, s] ++ [zone]
  where
    yearDigits
        | generalized = twoDigits (y `div` 100) ++ twoDigits (y `mod` 100)
        | otherwise   = twoDigits (y `mod` 100)
    twoDigits n = map (\q -> fromIntegral (q + ord '0')) [n `div` 10, n `mod` 10]
    zone = if z then 90 else 0
-- | Render a time tuple as UTCTime contents (two-digit year).
putUTCTime :: (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putUTCTime = putTime False
-- | Render a time tuple as GeneralizedTime contents (four-digit year).
putGeneralizedTime :: (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putGeneralizedTime = putTime True
-- | Serialise an INTEGER's contents using 'bytesOfInt'.
putInteger :: Integer -> ByteString
putInteger = B.pack . bytesOfInt
-- | Serialise BIT STRING contents: the unused-bit count as one leading byte,
-- followed by the bit data.
putBitString :: Int -> L.ByteString -> ByteString
putBitString unusedBits payload = B.cons (fromIntegral unusedBits) (B.concat (L.toChunks payload))
-- | Flatten a lazy ByteString into a single strict one.
putString :: L.ByteString -> ByteString
putString = B.concat . L.toChunks
-- | Serialise OBJECT IDENTIFIER contents.
--
-- The first two arcs are packed into one subidentifier (oid1 * 40 + oid2).
-- That value is encoded with the same variable-length encoder as every other
-- arc, so values of 128 and above (e.g. 2.100) become multi-byte instead of
-- being truncated to a single malformed byte. A zero subidentifier is written
-- explicitly as one 0 byte.
--
-- NOTE(review): this assumes putVarEncodingIntegral emits the base-128,
-- continuation-bit form, as its existing use for the later arcs requires.
putOID :: [Integer] -> ByteString
putOID (oid1 : oid2 : suboids) =
    B.concat $ map encode (oid1 * 40 + oid2 : suboids)
  where
    encode x
        | x == 0    = B.singleton 0
        | otherwise = putVarEncodingIntegral x
putOID _ = error "putOID: an OID needs at least two arcs"